diff --git a/.gitignore b/.gitignore index 900e5a53cbcf3bbb5e00389cca004c49f8600a66..bdcb067fc26d2a18ed88034ab616c08095794e17 100644 --- a/.gitignore +++ b/.gitignore @@ -4,12 +4,11 @@ node_modules /.bazelrc /.tf_configure.bazelrc /bazel-* -/third_party/py/numpy/numpy_include -/tools/bazel.rc +/bazel_pip +/third_party/eigen3/mkl_include +/third_party/mkl/* /tools/python_bin_path.sh /tools/git/gen -/util/python/python_include -/util/python/python_lib /pip_test /_python_build *.pyc diff --git a/.mention-bot b/.mention-bot deleted file mode 100644 index 9e4858977f5da2992ccc4053dfbbda3f5f86ee90..0000000000000000000000000000000000000000 --- a/.mention-bot +++ /dev/null @@ -1,11 +0,0 @@ -{ - "maxReviewers": 2, - "numFilesToCheck": 10, - "userBlacklist": ["tensorflower-gardener"], - "requiredOrgs": ["tensorflow"], - "skipAlreadyAssignedPR": true, - "skipAlreadyMentionedPR": true, - "skipTitle": "Branch", - "delayed": true, - "delayedUntil": "10m" -} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5ae5c0fbbcd5b8da7e3f3f98e01f455e0c82e588..c78b6b1a150c98fa379a87f935e77b5803837f11 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,3 +27,140 @@ contributions, often because we probably won't get to them right now. If you decide to start on an issue, leave a comment so that other people know that you're working on it. If you want to help out, but not alone, use the issue comment thread to coordinate. + +### Contribution guidelines and standards + +Before sending your pull request for +[review](https://github.com/tensorflow/tensorflow/pulls), +make sure your changes are consistent with the guidelines and follow the +TensorFlow coding style. + +#### General guidelines and philosophy for contribution + +* Include unit tests when you contribute new features, as they help to + a) prove that your code works correctly, b) guard against future breaking + changes to lower the maintenance cost. 
+* Bug fixes also generally require unit tests, because the presence of bugs + usually indicates insufficient test coverage. +* Keep API compatibility in mind when you change code in core TensorFlow, + e.g., code in [tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core) and [tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python). + TensorFlow has reached version 1 and hence cannot make + non-backward-compatible API changes without a major release. Reviewers of your + pull request will comment on any API compatibility issues. +* When you contribute a new feature to TensorFlow, the maintenance burden is (by + default) transferred to the TensorFlow team. This means that the benefit of the + contribution must be compared against the cost of maintaining the feature. +* Full new features (e.g., a new op implementing a cutting-edge algorithm) + typically will live in + [tensorflow/contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib) + to get some airtime before a decision is made regarding whether they are to be + migrated to the core. + +#### License + +Include a license at the top of new files. 
+ +* [C/C++ license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op.cc#L1) +* [Python license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn.py#L1) +* [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1) +* [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1) +* [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2) +* [HTML license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/dist/index.html#L2) +* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_backend/backend.ts#L1) + +Bazel BUILD files also need to include a license section, e.g., +[BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61). + +#### C++ coding style + +Changes to TensorFlow C++ code should conform to +[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). + +Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do: + +```bash +apt-get install -y clang-tidy +``` + +You can check a C/C++ file by doing: + + +```bash +clang-format <my_cc_file> --style=google > /tmp/my_cc_file.cc +diff <my_cc_file> /tmp/my_cc_file.cc +``` + +#### Python coding style + +Changes to TensorFlow Python code should conform to +[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) + +Use `pylint` to check your Python changes. 
To install `pylint` and +retrieve TensorFlow's custom style definition: + +```bash +pip install pylint +wget -O /tmp/pylintrc https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc +``` + +To check a file with `pylint`: + +```bash +pylint --rcfile=/tmp/pylintrc myfile.py +``` + +#### Coding style for other languages + +* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) +* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) +* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) + +#### Running sanity check + +If you have Docker installed on your system, you can perform a sanity check on +your changes by running the command: + +```bash +tensorflow/tools/ci_build/ci_build.sh CPU tensorflow/tools/ci_build/ci_sanity.sh +``` + +This will catch most license, Python coding style and BUILD file issues that +may exist in your changes. + +#### Running unit tests + +There are two ways to run TensorFlow unit tests. + +1. Using tools and libraries installed directly on your system. + + Refer to the + [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and + [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) + for the required packages. Alternatively, use the said + [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., + `tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu` + for development to avoid installing the packages directly on your system. 
+ + Once you have the packages installed, you can run a specific unit test in + bazel by doing as follows: + + If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add + the `cuda` option flag + + ```bash + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + + export flags="--config=opt --config=cuda -k" + ``` + + For example, to run all tests under tensorflow/python, do: + + ```bash + bazel test ${flags} //tensorflow/python/... + ``` + +2. Using Docker and TensorFlow's CI scripts. + + See + [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details. + diff --git a/README.md b/README.md index ff1124b99d4ace40b520f91a34c5944f21bcf092..2878dab2601351dabbfbcadfbe6a4ae94864ce56 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ guidelines](CONTRIBUTING.md).** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for tracking requests and bugs, but please see -[Community](tensorflow/docs_src/about/index.md#community) for general questions +[Community](https://www.tensorflow.org/community/) for general questions and discussion.** ## Installation diff --git a/RELEASE.md b/RELEASE.md index f078d336abb040edd81d7a5ded69f62d409119a4..d4e3bac01c6e250d81fb835a1058fe7316e4e0c2 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -18,6 +18,10 @@ If at all unsure, first test your code with TF 1.1; ensure it raises no errors, and then upgrade to TF 1.2. +## Bug Fixes and Other Changes +* In python, `Operation.get_attr` on type attributes returns the Python DType + version of the type to match expected get_attr documentation rather than the + protobuf enum. 
# Release 1.1.0 diff --git a/WORKSPACE b/WORKSPACE index cab8389a55ccfeddb9dc077c9b999edbe775f25d..b2d6fb542b0343b52cb7308102eef9478daba242 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "60fc6977908f999b23ca65698c2bb70213403824a84f7904310b6000d78be9ce", - strip_prefix = "rules_closure-5ca1dab6df9ad02050f7ba4e816407f88690cf7d", + sha256 = "4be8a887f6f38f883236e77bb25c2da10d506f2bf1a8e5d785c0f35574c74ca4", + strip_prefix = "rules_closure-aac19edc557aec9b603cd7ffe359401264ceff0d", urls = [ - "http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", # 2017-02-03 - "https://github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", + "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", # 2017-05-10 + "https://github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", ], ) diff --git a/configure b/configure index fad3fdbebd944e2bb54b719a7f43a8be840fe0ea..308369efd32a2c12e0da5d818b1a704b755bff7f 100755 --- a/configure +++ b/configure @@ -35,12 +35,9 @@ function is_windows() { fi } -function sed_hyphen_i() { - if is_macos; then - sed -i '' "$@" - else - sed -i "$@" - fi +function sed_in_place() { + sed -e $1 $2 > "$2.bak" + mv "$2.bak" $2 } function write_to_bazelrc() { @@ -51,11 +48,126 @@ function write_action_env_to_bazelrc() { write_to_bazelrc "build --action_env $1=\"$2\"" } +function python_path { + "$PYTHON_BIN_PATH" - <&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + PYTHON_BIN_PATH="" + # Retry + done + + if [ -z "$PYTHON_LIB_PATH" ]; then + # Split python_path into an array of paths, this allows path containing spaces + IFS=',' + python_lib_path=($(python_path)) + unset IFS + + if [ 1 = "$USE_DEFAULT_PYTHON_LIB_PATH" ]; then + 
PYTHON_LIB_PATH=${python_lib_path[0]} + echo "Using python library path: $PYTHON_LIB_PATH" + + else + echo "Found possible Python library paths:" + for x in "${python_lib_path[@]}"; do + echo " $x" + done + set -- "${python_lib_path[@]}" + echo "Please input the desired Python library path to use. Default is ["$1"]" + read b || true + if [ "$b" == "" ]; then + PYTHON_LIB_PATH=${python_lib_path[0]} + echo "Using python library path: $PYTHON_LIB_PATH" + else + PYTHON_LIB_PATH="$b" + fi + fi + fi + + if [ ! -x "$PYTHON_BIN_PATH" ] || [ -d "$PYTHON_BIN_PATH" ]; then + echo "PYTHON_BIN_PATH is not executable. Is it the python binary?" + exit 1 + fi + + local python_major_version=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import sys; print(sys.version_info[0]);') + if [ "$python_major_version" == "" ]; then + echo -e "\n\nERROR: Problem getting python version. Is $PYTHON_BIN_PATH the correct python binary?" + exit 1 + fi + + # Convert python path to Windows style before writing into bazel.rc + if is_windows; then + PYTHON_BIN_PATH="$(cygpath -m "$PYTHON_BIN_PATH")" + fi + + # Set-up env variables used by python_configure.bzl + write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH" + write_action_env_to_bazelrc "PYTHON_LIB_PATH" "$PYTHON_LIB_PATH" + write_to_bazelrc "build --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "build --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + write_to_bazelrc "build --force_python=py$python_major_version" + write_to_bazelrc "build --host_force_python=py$python_major_version" + write_to_bazelrc "build --python${python_major_version}_path=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "test --force_python=py$python_major_version" + write_to_bazelrc "test --host_force_python=py$python_major_version" + write_to_bazelrc "test --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "test --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + write_to_bazelrc "run --define 
PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "run --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + + # Write tools/python_bin_path.sh + echo "export PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" > tools/python_bin_path.sh +} + # This file contains customized config settings. rm -f .tf_configure.bazelrc touch .tf_configure.bazelrc touch .bazelrc -sed_hyphen_i "/tf_configure/d" .bazelrc +sed_in_place "/tf_configure/d" .bazelrc echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc # Delete any leftover BUILD files from the Makefile build, which would interfere @@ -65,61 +177,63 @@ if [ -d "${MAKEFILE_DOWNLOAD_DIR}" ]; then find ${MAKEFILE_DOWNLOAD_DIR} -type f -name '*BUILD' -delete fi -## Set up python-related environment settings -while true; do +setup_python + +## Set up MKL related environment settings +while [ "$TF_NEED_MKL" == "" ]; do fromuser="" - if [ -z "$PYTHON_BIN_PATH" ]; then - default_python_bin_path=$(which python || which python3 || true) - read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH - fromuser="1" - if [ -z "$PYTHON_BIN_PATH" ]; then - PYTHON_BIN_PATH=$default_python_bin_path - fi - fi - if [ -e "$PYTHON_BIN_PATH" ]; then - break - fi - echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - PYTHON_BIN_PATH="" - # Retry + read -p "Do you wish to build TensorFlow with MKL support? 
[y/N] " INPUT + fromuser="1" + case $INPUT in + [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;; + [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; + "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; + * ) echo "Invalid selection: " $INPUT;; + esac done -export PYTHON_BIN_PATH -write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH" -# TODO(ngiraldo): allow the user to optionally set PYTHON_INCLUDE_PATH and NUMPY_INCLUDE_PATH -## Set up MKL related environment settings -if false; then # Disable building with MKL for now - while [ "$TF_NEED_MKL" == "" ]; do +OSNAME=`uname -s` + +if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL + while [ "$TF_DOWNLOAD_MKL" == "" ]; do fromuser="" - read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT + read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT fromuser="1" case $INPUT in - [Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;; - [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - * ) echo "Invalid selection: " $INPUT;; + [Yy]* ) TF_DOWNLOAD_MKL=1;; + [Nn]* ) TF_DOWNLOAD_MKL=0;; + "" ) TF_DOWNLOAD_MKL=1;; + * ) echo "Invalid selection: " $INPUT; exit 1;; esac done - OSNAME=`uname -s` - - if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL + if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then DST=`dirname $0` - ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170209.tgz - GITHUB_RELEASE_TAG=v0.5 + ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz + GITHUB_RELEASE_TAG=v0.7 MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME" - if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then - wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL + if ! 
[ -e "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" ]; then + curl -fSsL -o "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" "${MKLURL}" fi tar -xzf $DST/third_party/mkl/$ARCHIVE_BASENAME -C $DST/third_party/mkl/ extracted_dir_name="${ARCHIVE_BASENAME%.*}" MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` - if [ "$OSNAME" == "Linux" ]; then + else + default_mkl_path=/opt/intel/mklml + fromuser="" + read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH + fromuser="1" + if [ -z "$MKL_INSTALL_PATH" ]; then + MKL_INSTALL_PATH=$default_mkl_path + fi + # Result returned from "read" will be used unexpanded. That make "~" unuseable. + # Going through one more level of expansion to handle that. + MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` + fi + + if [ "$OSNAME" == "Linux" ]; then # Full MKL configuration MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version? MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION? @@ -127,24 +241,29 @@ if false; then # Disable building with MKL for now # MKL-ML configuration MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version? MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION? 
- elif [ "$OSNAME" == "Darwin" ]; then + elif [ "$OSNAME" == "Darwin" ]; then echo "Darwin is unsupported yet"; exit 1 - fi + fi - if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then + if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/ ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/ ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include - else - echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist"; - exit 1 - fi - - if [ -z "$fromuser" ]; then + loc=$(locate -e libdl.so.2 | sed -n 1p) + ln -sf $loc third_party/mkl/libdl.so.2 + elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then + ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include + loc=$(locate -e libdl.so.2 | sed -n 1p) + ln -sf $loc third_party/mkl/libdl.so.2 + else + echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists"; exit 1 - fi + fi cat > third_party/mkl/mkl.config < third_party/mkl/mkl.config <> tools/bazel.rc for opt in $CC_OPT_FLAGS; do - echo "build:opt --cxxopt=$opt --copt=$opt" >> tools/bazel.rc + write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt" done # Run the gen_git_source to create links where bazel can track dependencies for @@ -321,31 +435,6 @@ done export TF_CUDA_CLANG write_action_env_to_bazelrc "TF_CUDA_CLANG" "$TF_CUDA_CLANG" -# Set up which gcc nvcc should use as the host compiler -# No need to set this on Windows -while [[ "$TF_CUDA_CLANG" != "1" ]] && ! is_windows && true; do - fromuser="" - if [ -z "$GCC_HOST_COMPILER_PATH" ]; then - default_gcc_host_compiler_path=$(which gcc || true) - read -p "Please specify which gcc should be used by nvcc as the host compiler. 
[Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH - fromuser="1" - if [ -z "$GCC_HOST_COMPILER_PATH" ]; then - GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path" - fi - fi - if [ -e "$GCC_HOST_COMPILER_PATH" ]; then - export GCC_HOST_COMPILER_PATH - write_action_env_to_bazelrc "GCC_HOST_COMPILER_PATH" "$GCC_HOST_COMPILER_PATH" - break - fi - echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - GCC_HOST_COMPILER_PATH="" - # Retry -done - # Set up which clang we should use as the cuda / host compiler. while [[ "$TF_CUDA_CLANG" == "1" ]] && true; do fromuser="" @@ -386,6 +475,11 @@ while true; do else default_cuda_path="$(cygpath -m "$CUDA_PATH")" fi + elif is_linux; then + # If the default doesn't exist, try an alternative default. + if [ ! -d $default_cuda_path ] && [ -d /opt/cuda ]; then + default_cuda_path=/opt/cuda + fi fi read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH fromuser="1" @@ -425,6 +519,35 @@ while true; do CUDA_TOOLKIT_PATH="" done +# Set up which gcc nvcc should use as the host compiler +# No need to set this on Windows +while [[ "$TF_CUDA_CLANG" != "1" ]] && ! is_windows && true; do + fromuser="" + if [ -z "$GCC_HOST_COMPILER_PATH" ]; then + default_gcc_host_compiler_path=$(which gcc || true) + cuda_bin_symlink="$CUDA_TOOLKIT_PATH/bin/gcc" + if [ -L "$cuda_bin_symlink" ]; then + default_gcc_host_compiler_path=$(readlink $cuda_bin_symlink) + fi + read -p "Please specify which gcc should be used by nvcc as the host compiler. 
[Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH + fromuser="1" + if [ -z "$GCC_HOST_COMPILER_PATH" ]; then + GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path" + fi + fi + if [ -e "$GCC_HOST_COMPILER_PATH" ]; then + export GCC_HOST_COMPILER_PATH + write_action_env_to_bazelrc "GCC_HOST_COMPILER_PATH" "$GCC_HOST_COMPILER_PATH" + break + fi + echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + GCC_HOST_COMPILER_PATH="" + # Retry +done + # Find out where the cuDNN library is installed while true; do # Configure the cuDNN version to use. diff --git a/tensorflow/BUILD b/tensorflow/BUILD index cca5a80314387ae0e6a33d3f318d03c49f5d8b0e..e3b88291057e6bec39a5c23e00d96152960c2bbf 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -202,7 +202,6 @@ filegroup( "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", - "//tensorflow/compiler/xla/port:all_files", "//tensorflow/compiler/xla/service:all_files", "//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files", @@ -213,6 +212,7 @@ filegroup( "//tensorflow/contrib:all_files", "//tensorflow/contrib/android:all_files", "//tensorflow/contrib/batching:all_files", + "//tensorflow/contrib/batching/kernels:all_files", "//tensorflow/contrib/batching/test_util:all_files", "//tensorflow/contrib/batching/util:all_files", "//tensorflow/contrib/bayesflow:all_files", @@ -262,6 +262,7 @@ filegroup( "//tensorflow/contrib/seq2seq:all_files", "//tensorflow/contrib/session_bundle:all_files", "//tensorflow/contrib/session_bundle/example:all_files", + "//tensorflow/contrib/signal:all_files", "//tensorflow/contrib/slim:all_files", "//tensorflow/contrib/slim/python/slim/data:all_files", "//tensorflow/contrib/slim/python/slim/nets:all_files", @@ -289,6 +290,7 @@ filegroup( "//tensorflow/core/grappler/costs:all_files", 
"//tensorflow/core/grappler/inputs:all_files", "//tensorflow/core/grappler/optimizers:all_files", + "//tensorflow/core/grappler/utils:all_files", "//tensorflow/core/kernels:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/ops/compat:all_files", @@ -298,6 +300,7 @@ filegroup( "//tensorflow/core/util/ctc:all_files", "//tensorflow/core/util/tensor_bundle:all_files", "//tensorflow/examples/android:all_files", + "//tensorflow/examples/benchmark:all_files", "//tensorflow/examples/how_tos/reading_data:all_files", "//tensorflow/examples/image_retraining:all_files", "//tensorflow/examples/label_image:all_files", @@ -314,7 +317,10 @@ filegroup( "//tensorflow/python:all_files", "//tensorflow/python/debug:all_files", "//tensorflow/python/estimator:all_files", + "//tensorflow/python/feature_column:all_files", "//tensorflow/python/kernel_tests:all_files", + "//tensorflow/python/kernel_tests/distributions:all_files", + "//tensorflow/python/ops/distributions:all_files", "//tensorflow/python/saved_model:all_files", "//tensorflow/python/tools:all_files", "//tensorflow/tensorboard:all_files", @@ -322,6 +328,100 @@ filegroup( "//tensorflow/tensorboard/backend:all_files", "//tensorflow/tensorboard/backend/event_processing:all_files", "//tensorflow/tensorboard/components:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard/demo/data:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_backend:all_files", + "//tensorflow/tensorboard/components/tf_backend_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_backend_d3v4/test:all_files", + "//tensorflow/tensorboard/components/tf_color_scale:all_files", + "//tensorflow/tensorboard/components/tf_color_scale/demo:all_files", + 
"//tensorflow/tensorboard/components/tf_color_scale_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_color_scale_d3v4/test:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common/demo:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4/test:all_files", + "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data:all_files", + "//tensorflow/tensorboard/components/tf_distribution_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_globals:all_files", + "//tensorflow/tensorboard/components/tf_globals_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph:all_files", + "//tensorflow/tensorboard/components/tf_graph/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_app:all_files", + "//tensorflow/tensorboard/components/tf_graph_app/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_app_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_app_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_board:all_files", + "//tensorflow/tensorboard/components/tf_graph_board/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_board_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_board_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_common:all_files", + "//tensorflow/tensorboard/components/tf_graph_common_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_controls:all_files", + "//tensorflow/tensorboard/components/tf_graph_controls/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_controls_d3v4:all_files", + 
"//tensorflow/tensorboard/components/tf_graph_controls_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_graph_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_info:all_files", + "//tensorflow/tensorboard/components/tf_graph_info/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_info_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_info_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_loader:all_files", + "//tensorflow/tensorboard/components/tf_graph_loader/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_loader_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_histogram_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_histogram_dashboard/demo/data:all_files", + "//tensorflow/tensorboard/components/tf_histogram_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_image_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_image_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_image_dashboard/demo/data:all_files", + "//tensorflow/tensorboard/components/tf_image_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_imports:all_files", + "//tensorflow/tensorboard/components/tf_imports_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_option_selector_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files", + 
"//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_storage:all_files", + "//tensorflow/tensorboard/components/tf_storage_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_storage_d3v4/test:all_files", + "//tensorflow/tensorboard/components/tf_tensorboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard/demo:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard/demo/data:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_data_summary:all_files", + "//tensorflow/tensorboard/components/vz_distribution_chart:all_files", + "//tensorflow/tensorboard/components/vz_distribution_chart/demo:all_files", + "//tensorflow/tensorboard/components/vz_distribution_chart_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files", + "//tensorflow/tensorboard/components/vz_histogram_timeseries/demo:all_files", + "//tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_line_chart:all_files", + "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files", + "//tensorflow/tensorboard/components/vz_line_chart_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_projector:all_files", + "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_projector_d3v4/test:all_files", + "//tensorflow/tensorboard/components/vz_sorting:all_files", + "//tensorflow/tensorboard/components/vz_sorting/test:all_files", + 
"//tensorflow/tensorboard/components/vz_sorting_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_sorting_d3v4/test:all_files", + "//tensorflow/tensorboard/demo:all_files", + "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", "//tensorflow/tensorboard/lib:all_files", "//tensorflow/tensorboard/plugins:all_files", "//tensorflow/tensorboard/plugins/projector:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 4ad69ae3fbdfbb8e3ab3c868fea4976c59dd9e71..3ab4e8efcdb5b05cf8922edd302e7cbf3a3597f1 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -58,6 +58,7 @@ tf_cuda_library( "//tensorflow/cc/saved_model:loader", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", + "//tensorflow/cc:grad_ops", "//tensorflow/cc:scope_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 0f66a47b4ad66ecfbbdfbed15dddb378c62308b9..f4775783f9f88c941445b62603c92cae00d34715 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -738,8 +738,7 @@ tensorflow::string OutputName(const TF_Output& output) { const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, const char* attr_name, TF_Status* status) { - const tensorflow::AttrValue* attr = - tensorflow::AttrSlice(oper->node.def()).Find(attr_name); + const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); if (attr == nullptr) { status->status = InvalidArgument("Operation has no attr named '", attr_name, "'."); @@ -1101,14 +1100,14 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, if (status->status.ok()) { // Run shape inference function for newly added node. - // - // TODO(b/28152992): Enable returning the result of this - // code-path once we have converted all python shape functions - // to call their C++ versions. 
- desc->graph->refiner.AddNode(ret).IgnoreError(); - + status->status = desc->graph->refiner.AddNode(ret); + } + if (status->status.ok()) { // Add the node to the name-to-node mapping. desc->graph->name_map[ret->name()] = ret; + } else if (ret != nullptr) { + desc->graph->graph.RemoveNode(ret); + ret = nullptr; } } @@ -1135,7 +1134,7 @@ const char* TF_OperationOpType(TF_Operation* oper) { } const char* TF_OperationDevice(TF_Operation* oper) { - return oper->node.def().device().c_str(); + return oper->node.requested_device().c_str(); } int TF_OperationNumOutputs(TF_Operation* oper) { @@ -1150,8 +1149,8 @@ TF_DataType TF_OperationOutputType(TF_Output oper_out) { int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, TF_Status* status) { NameRangeMap name_ranges; - status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), - nullptr, &name_ranges); + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { @@ -1172,8 +1171,8 @@ TF_DataType TF_OperationInputType(TF_Input oper_in) { int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, TF_Status* status) { NameRangeMap name_ranges; - status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), - &name_ranges, nullptr); + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { @@ -1411,26 +1410,27 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, } } -#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ - void func(TF_Operation* oper, const char* attr_name, c_type* value, \ - TF_Status* status) { \ - cpp_type v; \ - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &v); \ - *value = static_cast(v); \ - } \ 
- void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ - int max_values, TF_Status* status) { \ - const auto* attr = GetAttrValue(oper, attr_name, status); \ - if (!status->status.ok()) return; \ - if (attr->value_case() != tensorflow::AttrValue::kList) { \ - status->status = \ - InvalidArgument("Value for '", attr_name, "' is not a list."); \ - return; \ - } \ - const auto len = std::min(max_values, attr->list().list_field##_size()); \ - for (int i = 0; i < len; ++i) { \ - values[i] = static_cast(attr->list().list_field(i)); \ - } \ +#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ + void func(TF_Operation* oper, const char* attr_name, c_type* value, \ + TF_Status* status) { \ + cpp_type v; \ + status->status = \ + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \ + *value = static_cast(v); \ + } \ + void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ + int max_values, TF_Status* status) { \ + const auto* attr = GetAttrValue(oper, attr_name, status); \ + if (!status->status.ok()) return; \ + if (attr->value_case() != tensorflow::AttrValue::kList) { \ + status->status = \ + InvalidArgument("Value for '", attr_name, "' is not a list."); \ + return; \ + } \ + const auto len = std::min(max_values, attr->list().list_field##_size()); \ + for (int i = 0; i < len; ++i) { \ + values[i] = static_cast(attr->list().list_field(i)); \ + } \ } DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); @@ -1441,7 +1441,8 @@ DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, int64_t* value, int num_dims, TF_Status* status) { PartialTensorShape shape; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shape); + status->status = + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); if (!status->status.ok()) return; 
auto len = std::min(shape.dims(), num_dims); for (int i = 0; i < len; ++i) { @@ -1455,7 +1456,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, int storage_size, TF_Status* status) { std::vector shapes; status->status = - tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shapes); + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); if (!status->status.ok()) return; auto len = std::min(static_cast(shapes.size()), max_values); int64_t* p = storage; @@ -1522,7 +1523,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, TF_Tensor** value, TF_Status* status) { *value = nullptr; Tensor t; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &t); + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); if (!status->status.ok()) return; *value = new TF_Tensor{static_cast(t.dtype()), t.shape(), tensorflow::TensorCApi::Buffer(t)}; @@ -1533,7 +1534,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, TF_Tensor** values, int max_values, TF_Status* status) { std::vector ts; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &ts); + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); if (!status->status.ok()) return; const auto len = std::min(max_values, static_cast(ts.size())); for (int i = 0; i < len; ++i) { diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index e2aeef0d88f87f0e1567db81576c8639fe82b01b..ec9b01b388d1138644e28e3206e32726347b3d5e 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -95,7 +95,7 @@ TF_CAPI_EXPORT extern const char* TF_Version(); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. // The enum values here are identical to corresponding values in types.proto. 
-typedef enum { +typedef enum TF_DataType { TF_FLOAT = 1, TF_DOUBLE = 2, TF_INT32 = 3, // Int32 tensors are always in 'host' memory. @@ -127,7 +127,7 @@ TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); // -------------------------------------------------------------------------- // TF_Code holds an error code. The enum values here are identical to // corresponding values in error_codes.proto. -typedef enum { +typedef enum TF_Code { TF_OK = 0, TF_CANCELLED = 1, TF_UNKNOWN = 2, @@ -629,7 +629,7 @@ TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( int max_control_outputs); // TF_AttrType describes the type of the value of an attribute on an operation. -typedef enum { +typedef enum TF_AttrType { TF_ATTR_STRING = 0, TF_ATTR_INT = 1, TF_ATTR_FLOAT = 2, diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 0ddc59db20e6d8cf08f37155431285b69c625302..cdb7406c86e8b10d24c303615d13089272bcab5d 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" @@ -278,6 +279,19 @@ static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +// Create a tensor with values of type TF_INT8 provided by `values`. 
+static TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, + const char* values) { + int64_t num_values = 1; + for (int i = 0; i < num_dims; ++i) { + num_values *= dims[i]; + } + TF_Tensor* t = + TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); + memcpy(TF_TensorData(t), values, sizeof(char) * num_values); + return t; +} + static TF_Tensor* Int32Tensor(int32 v) { const int num_bytes = sizeof(int32); int32* values = new int32[1]; @@ -293,16 +307,21 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, return TF_FinishOperation(desc, s); } -TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s, - const char* name = "scalar") { - unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); +TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, + const char* name = "const") { TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); - TF_SetAttrTensor(desc, "value", tensor.get(), s); + TF_SetAttrTensor(desc, "value", t, s); if (TF_GetCode(s) != TF_OK) return nullptr; - TF_SetAttrType(desc, "dtype", TF_INT32); + TF_SetAttrType(desc, "dtype", TF_TensorType(t)); return TF_FinishOperation(desc, s); } +TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar") { + unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "add") { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); @@ -1093,6 +1112,35 @@ TEST(CAPI, SessionPRun) { TF_DeleteStatus(s); } +TEST(CAPI, ShapeInferenceError) { + // TF_FinishOperation should fail if the shape of the added operation cannot + // be inferred. 
+ TF_Status* status = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create this failure by trying to add two nodes with incompatible shapes + // (A tensor with shape [2] and a tensor with shape [3] cannot be added). + const char data[] = {1, 2, 3}; + const int64_t vec2_dims[] = {2}; + unique_tensor_ptr vec2_tensor( + Int8Tensor(vec2_dims, TF_ARRAYSIZE(vec2_dims), data), TF_DeleteTensor); + TF_Operation* vec2 = Const(vec2_tensor.get(), graph, status, "vec2"); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const int64_t vec3_dims[] = {3}; + unique_tensor_ptr vec3_tensor( + Int8Tensor(vec3_dims, TF_ARRAYSIZE(vec3_dims), data), TF_DeleteTensor); + TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3"); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Operation* add = Add(vec2, vec3, graph, status); + ASSERT_NE(TF_OK, TF_GetCode(status)); + ASSERT_TRUE(add == nullptr); + + TF_DeleteGraph(graph); + TF_DeleteStatus(status); +} + TEST(CAPI, ColocateWith) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -1535,7 +1583,8 @@ Test op with no grad registered. x: input y: output -)doc"); +)doc") + .SetShapeFn(tensorflow::shape_inference::UnknownShape); class CApiGradientsTest : public ::testing::Test { protected: @@ -1801,18 +1850,6 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } -// Create a tensor with values of type TF_INT8 provided by `values`. 
-TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { - int64_t num_values = 1; - for (int i = 0; i < num_dims; ++i) { - num_values *= dims[i]; - } - TF_Tensor* t = - TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); - memcpy(TF_TensorData(t), values, sizeof(char) * num_values); - return t; -} - void StringVectorToArrays(const std::vector& v, std::unique_ptr* ptrs, std::unique_ptr* lens) { @@ -1828,9 +1865,13 @@ void StringVectorToArrays(const std::vector& v, // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other // will have list(type). -#define ATTR_TEST_REGISTER_OP(type) \ - REGISTER_OP("CApiAttributesTestOp" #type).Attr("v: " #type); \ - REGISTER_OP("CApiAttributesTestOpList" #type).Attr("v: list(" #type ")") +#define ATTR_TEST_REGISTER_OP(type) \ + REGISTER_OP("CApiAttributesTestOp" #type) \ + .Attr("v: " #type) \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape); \ + REGISTER_OP("CApiAttributesTestOpList" #type) \ + .Attr("v: list(" #type ")") \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape) ATTR_TEST_REGISTER_OP(string); ATTR_TEST_REGISTER_OP(int); ATTR_TEST_REGISTER_OP(float); diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh new file mode 100755 index 0000000000000000000000000000000000000000..ea2eed011c62e535047e5f40d1f5b34fbb6ad2be --- /dev/null +++ b/tensorflow/c/generate-pc.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash + +TF_PREFIX='/usr/local' + +usage() { + echo "Usage: $0 OPTIONS" + echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-v, --version\tset TensorFlow version" + echo -e "-h, --help\tdisplay this message" +} + +# read the options +ARGS=`getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@"` +eval set -- "$ARGS" + +# extract options and their arguments into variables. 
+while true ; do + case "$1" in + -h|--help) usage ; exit ;; + -p|--prefix) + case "$2" in + "") shift 2 ;; + *) TF_PREFIX=$2 ; shift 2 ;; + esac ;; + -v|--version) + case "$2" in + "") shift 2 ;; + *) TF_VERSION=$2 ; shift 2 ;; + esac ;; + --) shift ; echo "Try '$0 --help' for more information."; exit 1 ;; + *) echo "Internal error! Try '$0 --help' for more information." ; exit 1 ;; + esac +done + +echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" + +cat << EOF > tensorflow.pc +prefix=${TF_PREFIX} +exec_prefix=\${prefix} +libdir=\${exec_prefix}/lib +includedir=\${prefix}/include + +Name: TensorFlow +Version: ${TF_VERSION} +Description: Library for computation using data flow graphs for scalable machine learning +Requires: +Libs: -L\${libdir} -ltensorflow +Cflags: -I\${includedir} +EOF diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 42fa139282a524f761dbebb2b55cf1ae043526e5..8d4260a0b9ca38593a912398e8460d826fb31ccf 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -91,6 +91,7 @@ cc_library( deps = [ ":array_grad", ":math_grad", + ":nn_grad", ], ) @@ -388,6 +389,16 @@ tf_gen_op_wrappers_cc( visibility = ["//tensorflow:internal"], ) +tf_gen_op_wrappers_cc( + name = "functional_ops", + include_internal_ops = 1, + op_lib_names = [ + "functional_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + tf_gen_op_wrappers_cc( name = "resource_variable_ops", include_internal_ops = 1, diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index 2732f3f5010d7522a1cf8631183e9b4df7ac86d8..2879445441d0a80c1320a30976412b416feaecc9 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/cc/client/client_session.h" #include +#include #include #include "tensorflow/core/platform/env.h" @@ -31,7 +32,7 @@ class ClientSession::Impl { friend class ClientSession; Impl(Session* session, std::shared_ptr graph) - : session_(session), graph_(graph) {} + : session_(session), graph_(std::move(graph)) {} static SessionOptions MakeDefaultSessionOptions(const string& target); Status MaybeExtendGraph() const; diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index b7e9948e9d4f3ed3e655802fce4d1febcf68c07f..71aa986f918de68822d457422f6c7a73d6253819 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -198,7 +198,7 @@ string PrintTensorProto(const TensorProto& proto) { ").AsTensorProto()"); } -string PrintAttrValue(string op, const AttrValue& attr_value) { +string PrintAttrValue(const string& op, const AttrValue& attr_value) { switch (attr_value.value_case()) { case AttrValue::kS: return PrintString(attr_value.s()); @@ -740,11 +740,10 @@ void OpInfo::GetOutput(string* out) const { return; } strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n"); - strings::StrAppend( - out, - " ::tensorflow::Status _status_ = " - "::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), " - "nullptr, &_outputs_range);\n"); + strings::StrAppend(out, + " ::tensorflow::Status _status_ = " + "::tensorflow::NameRangesForNode(*ret, ret->op_def(), " + "nullptr, &_outputs_range);\n"); strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str, ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 6dc0d84c16d5b534341575b384997cc398c80bec..5da23036eaadbef270ba839357dc4613bf3bf490 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -32,10 +32,11 @@ Output Linear(const Scope& scope, Input x, Input 
w, Input b) { return BiasAdd(cop_scopes.last, m, b); } -void GetColocationConstraints(Output tensor, std::vector* constraints) { +void GetColocationConstraints(const Output& tensor, + std::vector* constraints) { constraints->clear(); - TF_EXPECT_OK( - GetNodeAttr(tensor.op().node()->def(), kColocationAttrName, constraints)); + TF_EXPECT_OK(GetNodeAttr(tensor.op().node()->attrs(), kColocationAttrName, + constraints)); } } // namespace @@ -158,11 +159,11 @@ TEST(CCOpTest, KernelLabel) { Scope root = Scope::NewRootScope(); auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f); TF_EXPECT_OK(root.status()); - const auto& attrs = add.z.op().node()->def().attr(); - ASSERT_TRUE(attrs.find("_kernel") != attrs.end()); - auto kernel_attr = attrs.find("_kernel")->second; - TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string")); - EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel"); + AttrSlice attrs = add.z.op().node()->attrs(); + const auto* kernel_attr = attrs.Find("_kernel"); + ASSERT_TRUE(kernel_attr); + TF_EXPECT_OK(AttrValueHasType(*kernel_attr, "string")); + EXPECT_EQ(kernel_attr->s(), "AddWithKernelLabel"); } TEST(CCOpTest, ColocateWith) { @@ -189,8 +190,7 @@ TEST(CCOpTest, ColocateWith) { Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4); auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7); - const auto& attrs = c6.op().node()->def().attr(); - EXPECT_TRUE(attrs.find("_class") == attrs.end()); + EXPECT_FALSE(c6.op().node()->attrs().Find("_class")); } TEST(CCOpTest, TemplatedConst) { diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 7783bdce3a7ef72ed157d620bf43517af79e1aaf..6a249825812b4d39b55f7170a35436b6ae88c020 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -260,7 +260,7 @@ TEST_F(GradientsTest, StackUnstack_StopBackprop) { } TEST_F(GradientsTest, DependentGradOutputs) { - // Tests that dependant gradients (in this 
case the gradients w.r.t to the + // Tests that dependent gradients (in this case the gradients w.r.t to the // output and one input of MatMul) are computed properly. // Create two chained MatMul ops. diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 8b7fc1406f06e80590a98c65dd79be858b21cc0d..32c0822de69da7989ceaa4028539db928b6fcea3 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -271,9 +271,9 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, std::unordered_set Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { std::unordered_set current_constraints(colocation_constraints_); - const NodeDef& node_def = colocate_with_op.node()->def(); + const AttrSlice attrs = colocate_with_op.node()->attrs(); std::vector node_constraints; - if (GetNodeAttr(node_def, kColocationAttrName, &node_constraints).ok()) { + if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { for (const string& entry : node_constraints) { StringPiece s(entry); if (s.Consume(kColocationGroupPrefix)) { diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 26abd2438e652f29a1d25caf689ab0606a12b00a..37f07e71a0dff9144f193679bbcfcf581c1538cf 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -43,9 +43,9 @@ Status PackGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int N; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N)); int axis; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); grad_outputs->reserve(N); auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis)); @@ -60,7 +60,7 @@ Status UnpackGrad(const Scope& scope, const Operation& op, const std::vector& 
grad_inputs, std::vector* grad_outputs) { int axis; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis))); return scope.status(); } @@ -162,7 +162,7 @@ Status CheckNumericsGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string message; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "message", &message)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message)); string err_msg = strings::StrCat( "Not a number (NaN) or infinity (Inf) values detected in gradient. ", message); @@ -215,9 +215,9 @@ Status ReverseSequenceGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { auto seq_lengths = op.input(1); int batch_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim)); int seq_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim)); grad_outputs->push_back( ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, ReverseSequence::BatchDim(batch_dim))); @@ -267,7 +267,8 @@ Status SpaceToBatchGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( BatchToSpace(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -290,7 +291,8 @@ Status BatchToSpaceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), 
"block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -313,7 +315,8 @@ Status SpaceToDepthGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -323,7 +326,8 @@ Status DepthToSpaceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -333,7 +337,7 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); @@ -346,7 +350,7 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); return scope.status(); diff --git 
a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index aff0653139538820a705371ee9446a3d38ca69b5..8c1a01f518f9ad3a4571c2f36c01d4eae712e813 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -21,6 +21,17 @@ namespace tensorflow { namespace ops { namespace { +// Conjugate helper function returns the conjugate of an Output if it +// is complex valued. +Output ConjugateHelper(const Scope& scope, const Output& out) { + DataType dtype = out.type(); + if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { + return Conj(scope, out); + } else { + return out; + } +} + // TODO(andydavis) Add control dependencies to gradient functions (as needed). Status AbsGrad(const Scope& scope, const Operation& op, @@ -44,9 +55,11 @@ REGISTER_GRADIENT_OP("Neg", NegGrad); Status InvGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // dx = dy * (-1 * (y * y)) + // dy/dx = -1/x^2 = -y^2 + auto dydx = Neg(scope, Square(scope, op.output(0))); + // grad(x) = grad(y) * conj(dy/dx) grad_outputs->push_back( - Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0))))); + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Inv", InvGrad); @@ -55,10 +68,12 @@ REGISTER_GRADIENT_OP("Reciprocal", InvGrad); Status SquareGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // dx = dy * (2 * x) + // dy/dx = (2 * x) auto two = Cast(scope, Const(scope, 2), op.input(0).type()); + auto dydx = Mul(scope, two, op.input(0)); + // grad(x) = grad(y) * conj(dy/dx) grad_outputs->push_back( - Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0)))); + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Square", SquareGrad); @@ -68,11 +83,12 @@ Status SqrtGrad(const Scope& scope, const Operation& op, std::vector* 
grad_outputs) { // y = sqrt(x) // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y) - // dx = dy * (0.5 * (1 / y)) auto y_inv = Reciprocal(scope, op.output(0)); auto half = Cast(scope, Const(scope, 0.5), op.input(0).type()); - auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv)); - grad_outputs->push_back(dx); + auto dydx = Mul(scope, half, y_inv); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Sqrt", SqrtGrad); @@ -82,14 +98,14 @@ Status RsqrtGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = 1/x^1/2 = x^-1/2 // dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1 - // dx = dy * (-1/2 * y * x^-1) auto x_inv = Reciprocal(scope, op.input(0)); auto y = op.output(0); auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type()); auto a = Mul(scope, neghalf, x_inv); - auto b = Mul(scope, a, y); - auto dx = Mul(scope, grad_inputs[0], b); - grad_outputs->push_back(dx); + auto dydx = Mul(scope, a, y); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad); @@ -97,10 +113,11 @@ REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad); Status ExpGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // y = exp(x) - // dy/dx = exp(x) - // dx = dy * y - grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0))); + // dy/dx = exp(x) = y + // grad(x) = grad(y) * conj(dy/dx) + // = grad(y) * conj(y) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0)))); return scope.status(); } REGISTER_GRADIENT_OP("Exp", ExpGrad); @@ -108,10 +125,12 @@ REGISTER_GRADIENT_OP("Exp", ExpGrad); Status Expm1Grad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { 
- // f(x) = expm1(x) - // df/dx = exp(x) - // dx = dy * exp(x) - grad_outputs->push_back(Mul(scope, grad_inputs[0], Exp(scope, op.input(0)))); + // y = expm1(x) + // dy/dx = exp(x) + auto dydx = Exp(scope, op.input(0)); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Expm1", Expm1Grad); @@ -119,11 +138,12 @@ REGISTER_GRADIENT_OP("Expm1", Expm1Grad); Status LogGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // f(x) = log(x) = y - // df/dx = 1 / x - // dx = dy * (1 / x) + // y = log(x) + // dy/dx = 1 / x + auto dydx = Reciprocal(scope, op.input(0)); + // grad(x) = grad(y) * conj(dy/dx) grad_outputs->push_back( - Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0)))); + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Log", LogGrad); @@ -131,12 +151,13 @@ REGISTER_GRADIENT_OP("Log", LogGrad); Status Log1pGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // f(x) = log1p(x) = y - // df/dx = 1 / (1 + x) - // dx = dy * (1 / (1 + x)) + // y = log1p(x) + // dy/dx = 1 / (1 + x) auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto dydx = Reciprocal(scope, Add(scope, one, op.input(0))); + // grad(x) = grad(y) * conj(dy/dx) grad_outputs->push_back( - Div(scope, grad_inputs[0], Add(scope, one, op.input(0)))); + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Log1p", Log1pGrad); @@ -146,11 +167,12 @@ Status TanhGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = tanh(x) // dy/dx = 1 - (tanh(x))^2 = 1 - y^2 - // dx = dy * (1 - y^2) auto y2 = Square(scope, op.output(0)); auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); - auto dx = Mul(scope, grad_inputs[0], 
Sub(scope, one, y2)); - grad_outputs->push_back(dx); + auto dydx = Sub(scope, one, y2); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Tanh", TanhGrad); @@ -160,11 +182,13 @@ Status SigmoidGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = 1 / (1 + exp(-x)) // dy/dx = y * (1 - y) - // dx = dy * y * (1 - y) auto y = op.output(0); auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); - auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y))); - grad_outputs->push_back(dx); + auto dydx = Mul(scope, y, Sub(scope, one, y)); + // dx = dy * y * (1 - y) + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad); @@ -185,9 +209,10 @@ Status SinGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = sin(x) // dy/dx = cos(x) - // dx = dy * cos(x) - auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0))); - grad_outputs->push_back(dx); + auto dydx = Cos(scope, op.input(0)); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Sin", SinGrad); @@ -197,9 +222,10 @@ Status CosGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = cos(x) // dy/dx = -sin(x) - // dx = dy * -sin(x) - auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0)))); - grad_outputs->push_back(dx); + auto dydx = Neg(scope, Sin(scope, op.input(0))); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Cos", CosGrad); @@ -208,12 +234,12 @@ Status AsinGrad(const Scope& scope, const Operation& op, 
const std::vector& grad_inputs, std::vector* grad_outputs) { // y = asin(x) - // dy/dx = 1 / (1 - x * x)^1/2 - // dx = dy * (1 / (1 - x * x)^1/2) + // dy/dx = 1 / sqrt(1 - x^2) auto x2 = Square(scope, op.input(0)); auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))); - auto dx = Mul(scope, grad_inputs[0], dydx); + // grad(x) = grad(y) * conj(dy/dx) + auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)); grad_outputs->push_back(dx); return scope.status(); } @@ -239,9 +265,9 @@ Status TanGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = tan(x) // dy/dx = sec(x)^2 = 1 / cos(x)^2 - // dx = dy * (1 / cos(x)^2) auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0)))); - auto dx = Mul(scope, grad_inputs[0], dydx); + // grad(x) = grad(y) * conj(dy/dx) + auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)); grad_outputs->push_back(dx); return scope.status(); } @@ -324,7 +350,7 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, const string& attr_adj_x, const string& attr_adj_y, std::vector* grad_outputs) { DataType dtype; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype)); if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { return errors::Unimplemented( "MatMul gradient for complex data type is not supported yet."); @@ -332,8 +358,10 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, bool ta; bool tb; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta)); - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb)); if (!ta && !tb) { return MatMulGradHelper(scope, is_batch, 
grad_inputs[0], false, op.input(1), diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index d7278929d4651f17d25670934b15e6da33d6a960..de6baa176936bcda7d0899c3795e1fbd37627058 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -56,23 +56,25 @@ class CWiseUnaryGradTest : public ::testing::Test { ATAN }; - void TestCWiseGrad(UnaryOpType op_type, std::function x_fn, - std::function dy_fn, - std::function dx_fn) { - Tensor x(DT_FLOAT, {2, 3, 2}); - auto x_flat = x.flat(); + template + void TestCWiseGrad(UnaryOpType op_type, const std::function& x_fn, + const std::function& dy_fn, + const std::function& dx_fn) { + DataType dtype = DataTypeToEnum::v(); + Tensor x(dtype, {2, 3, 2}); + auto x_flat = x.flat(); for (int i = 0; i < x_flat.size(); ++i) { x_flat(i) = x_fn(i); } - Tensor dy(DT_FLOAT, {2, 3, 2}); - auto dy_flat = dy.flat(); + Tensor dy(dtype, {2, 3, 2}); + auto dy_flat = dy.flat(); for (int i = 0; i < dy_flat.size(); ++i) { dy_flat(i) = dy_fn(x_flat(i)); } - Tensor dx(DT_FLOAT, {2, 3, 2}); - auto dx_flat = dx.flat(); + Tensor dx(dtype, {2, 3, 2}); + auto dx_flat = dx.flat(); for (int i = 0; i < dx_flat.size(); ++i) { dx_flat(i) = dx_fn(x_flat(i), dy_flat(i)); } @@ -146,7 +148,19 @@ class CWiseUnaryGradTest : public ::testing::Test { test::ExpectClose(output, dx); } - float RV(std::vector v) { return v[random::New64() % v.size()]; } + float RV(const std::vector& v) { + return v[random::New64() % v.size()]; + } + + complex64 CRV(const std::vector& v) { + return v[random::New64() % v.size()]; + } + + complex64 conjugate(const complex64& val) { + return complex64(val.real(), -val.imag()); + } + + const complex64 one_{1.0, 0}; Scope scope_; }; @@ -155,14 +169,14 @@ TEST_F(CWiseUnaryGradTest, Abs) { auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; auto dx_fn = [this](const float 
x, const float dy) { return x * dy; }; - TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Neg) { auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; auto dx_fn = [this](const float x, const float dy) { return -dy; }; - TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn); + TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Reciprocal) { @@ -171,14 +185,36 @@ TEST_F(CWiseUnaryGradTest, Reciprocal) { auto dx_fn = [this](const float x, const float dy) { return -(1 / (x * x)) * dy; }; - TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); + TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64 x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64 x, const complex64 dy) { + return -conjugate(one_ / (x * x)) * dy; + }; + TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Square) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); }; auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; }; - TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Square_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return conjugate(complex64(2, 0) * x) * dy; + }; + TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Sqrt) { @@ -187,7 +223,18 @@ TEST_F(CWiseUnaryGradTest, Sqrt) { auto dx_fn = [this](const float x, const float dy) { return 
dy * 0.5 * (1.0 / std::sqrt(x)); }; - TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sqrt_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy; + }; + TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Rsqrt) { @@ -196,7 +243,18 @@ TEST_F(CWiseUnaryGradTest, Rsqrt) { auto dx_fn = [this](const float x, const float dy) { return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x); }; - TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); + TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy; + }; + TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Exp) { @@ -205,7 +263,18 @@ TEST_F(CWiseUnaryGradTest, Exp) { auto dx_fn = [this](const float x, const float dy) { return dy * std::exp(x); }; - TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); + TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Exp_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(std::exp(x)); + }; + TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Expm1) { @@ -214,14 +283,36 @@ TEST_F(CWiseUnaryGradTest, Expm1) { auto dx_fn = [this](const float x, const float dy) { return dy * 
std::exp(x); }; - TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn); + TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Expm1_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(std::exp(x)); + }; + TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Log) { auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); }; - TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); + TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Log_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(one_ / x); + }; + TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Log1p) { @@ -230,7 +321,20 @@ TEST_F(CWiseUnaryGradTest, Log1p) { auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / (1.0 + x)); }; - TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn); + TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Log1p_Complex) { + auto x_fn = [this](const int i) { + return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / (one_ + conjugate(x)); + }; + TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Tanh) { @@ -240,7 +344,21 @@ TEST_F(CWiseUnaryGradTest, Tanh) { const float y = std::tanh(x); 
return dy * (1.0 - y * y); }; - TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); + TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Tanh_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + const complex64 y = std::tanh(x); + return dy * conjugate((one_ - y * y)); + }; + TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Sigmoid) { @@ -250,14 +368,28 @@ TEST_F(CWiseUnaryGradTest, Sigmoid) { const float y = 1.0 / (1.0 + std::exp(-x)); return dy * y * (1.0 - y); }; - TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + const complex64 y = one_ / (one_ + std::exp(-x)); + return dy * conjugate(y * (one_ - y)); + }; + TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Sign) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; auto dx_fn = [this](const float x, const float dy) { return 0.0; }; - TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Sin) { @@ -266,7 +398,20 @@ TEST_F(CWiseUnaryGradTest, Sin) { auto dx_fn = [this](const float x, const float dy) { return dy * std::cos(x); }; - TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sin_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, 
{3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(std::cos(x)); + }; + TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Cos) { @@ -275,7 +420,20 @@ TEST_F(CWiseUnaryGradTest, Cos) { auto dx_fn = [this](const float x, const float dy) { return dy * -1.0 * std::sin(x); }; - TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); + TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Cos_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(-std::sin(x)); + }; + TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Asin) { @@ -284,7 +442,24 @@ TEST_F(CWiseUnaryGradTest, Asin) { auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / std::sqrt(1.0 - x * x)); }; - TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Asin_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / conjugate(std::sqrt(one_ - x * x)); + }; + // TODO(kbsriram) + // Enable test when the asin kernel supports complex numbers + if (false) { + TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); + } } TEST_F(CWiseUnaryGradTest, Acos) { @@ -293,7 +468,24 @@ TEST_F(CWiseUnaryGradTest, Acos) { auto dx_fn = [this](const float x, const float dy) { return dy * (-1.0 / std::sqrt(1.0 - x * x)); }; - TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); +} + 
+TEST_F(CWiseUnaryGradTest, Acos_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / -conjugate(std::sqrt(one_ - x * x)); + }; + // TODO(kbsriram) + // Add test when the acos kernel supports complex numbers + if (false) { + TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); + } } TEST_F(CWiseUnaryGradTest, Tan) { @@ -303,7 +495,25 @@ TEST_F(CWiseUnaryGradTest, Tan) { const float cosx = std::cos(x); return dy * (1 / (cosx * cosx)); }; - TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Tan_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + const complex64 cosx = std::cos(x); + return dy / conjugate(cosx * cosx); + }; + // TODO(kbsriram) + // Enable when tan kernel supports complex inputs + if (false) { + TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); + } } TEST_F(CWiseUnaryGradTest, Atan) { @@ -312,7 +522,24 @@ TEST_F(CWiseUnaryGradTest, Atan) { auto dx_fn = [this](const float x, const float dy) { return dy * (1 / (1 + x * x)); }; - TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Atan_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / (one_ + x * x); + }; + // TODO(kbsriram) + // Add test when the atan kernel supports complex numbers + if (false) { + TestCWiseGrad(ATAN, x_fn, 
dy_fn, dx_fn); + } } class CWiseUnaryComplexGradTest : public ::testing::Test { diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 5a4770f879ff9a1422a63a88bd2b67ba201a0567..3184edeb3307cafcbfbc41c6477fd092ab613b46 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -28,9 +28,9 @@ void ExpectNodeEqual(const Node* n, gtl::ArraySlice values, TensorShape shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(tensor.dtype(), dtype); test::ExpectTensorEqual(tensor, test::AsTensor(values, shape)); } @@ -39,9 +39,9 @@ void ExpectTypeAndShape(const Node* n, DataType expected_dtype, TensorShape expected_shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(dtype, expected_dtype); EXPECT_EQ(expected_shape, TensorShape(tensor.shape())); } diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt index cd94ddf4a1b67d3b98da7769db95bbda294e76db..1dffb10c03379571907e921c1add98d1f11625c3 100644 --- a/tensorflow/cc/ops/op_gen_overrides.pbtxt +++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt @@ -22,7 +22,7 @@ op { name: "Where" input_rename: { from: "input" to: "condition" } } op { name: "ThreadUnsafeUnigramCandidateSampler", skip: true } # control_flow_ops -# TODO(josh11b): Hide Switch and Merge once we write and migrate users to +# TODO(joshl): Hide Switch and Merge once we write and migrate users to # a Cond() API. 
#op { name: "Switch" hide: true } #op { name: "Merge" hide: true } diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index b144bfc33e46c3db192cfb1e3ef8a0633e9fa519..908aa01a3470b67233c61d150ea955c1c13a8cd3 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -36,7 +36,7 @@ auto* load_attempt_count = monitoring::Counter<2>::New( "status"); auto* load_latency = monitoring::Counter<1>::New( "/tensorflow/cc/saved_model/load_latency", - "Latency in microseconds for SavedModels that were succesfully loaded.", + "Latency in microseconds for SavedModels that were successfully loaded.", "model_path"); constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index c52a56b6428fb8a8415ed53477ba3e81c57b0ded..c12005a4cab903c15a4f95efa0fdc3b8b2563942 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -73,7 +73,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:core_cpu", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 042a72745a78c4a11b22c85e3a094d78c4ab2ed5..bbdb342a623f5d4435e437fbb94e282b685751c9 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -152,8 +152,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, string RewriteWithName(const string& name, string code, const std::vector>& rewrites) { str_util::ReplaceAllPairs(&code, rewrites); - str_util::ReplaceAll(&code, "{{NAME}}", name); - return code; + return str_util::StringReplace(code, 
"{{NAME}}", name, /*replace_all=*/true); } // Generate methods for args (inputs). @@ -366,7 +365,7 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config, #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 46d7c03006a1344df17fc99c8b837f31ee86feb9..01963c6df4682ec8c23a93201d7fbbab63558060 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -15,7 +15,7 @@ #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 4b5534c164887ed0f3656808d8d328bb7b4f5975..0c7b97b01f43ea255ed4b7773ab5268396e7c306 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -27,7 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -203,14 +203,14 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config, for (const Node* n : graph->nodes()) { if (n->type_string() == kArgOp) { string feed_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id)); if (missing_feeds.erase(feed_id) == 0) { return errors::Aborted(kArgOp, " node found with unknown feed id: ", feed_id); } } else if (n->type_string() == kRetvalOp) { string fetch_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id)); if (missing_fetches.erase(fetch_id) == 0) { return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ", fetch_id); @@ -234,7 +234,7 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { for (Node* n : graph.nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); auto insert_result = indexed_arg_nodes.insert({index, n}); if (!insert_result.second) { const Node* dup = insert_result.first->second; @@ -264,9 +264,9 @@ Status CreateXlaArgs(const Graph& graph, for (const Node* node : arg_nodes) { XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape)); - 
TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } return Status::OK(); @@ -274,7 +274,8 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. -Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, +Status ConvertGraphToXla(xla::CompileOnlyClient* client, + std::unique_ptr graph, xla::Computation* computation, bool* has_context_arg) { // Create a device and context to convert the graph into an XLA computation. XlaOpRegistry::RegisterCompilationKernels(); @@ -288,18 +289,19 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, // Compile the graph into an XLA computation. 
XlaCompiler::Options compiler_options; compiler_options.client = client; - compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + DeviceType device_type(DEVICE_CPU_XLA_JIT); + compiler_options.device_type = &device_type; + compiler_options.flib_def = &graph->flib_def(); + compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; XlaCompiler compiler(compiler_options); - std::unique_ptr flib_run(NewFunctionLibraryRuntime( - compiler.device_mgr(), Env::Default(), compiler.device(), - graph->versions().producer(), &graph->flib_def(), OptimizerOptions())); XlaCompiler::CompilationResult result; - TF_RETURN_IF_ERROR(compiler.CompileGraph("tfcompile", std::move(graph), - flib_run.get(), xla_args, &result)); + TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), + "tfcompile", std::move(graph), + xla_args, &result)); *has_context_arg = result.requires_runtime_context; - *computation = std::move(result.computation); + *computation = std::move(*result.computation); int num_const_results = 0; for (int i = 0; i < result.outputs.size(); ++i) { @@ -333,7 +335,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, } // Compiles the XLA computation into executable code. -Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, +Status CompileXla(xla::CompileOnlyClient* client, + const xla::Computation& computation, const xla::cpu::CpuAotCompilationOptions& aot_opts, CompileResult* compile_result) { // Retrieves arg and result layouts from the computation. 
@@ -350,7 +353,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } - xla::LocalClient::AheadOfTimeComputationInstance instance; + xla::CompileOnlyClient::AotComputationInstance instance; instance.computation = &computation; instance.argument_layouts = std::move(arg_layouts); instance.result_layout = &pshape->result(); @@ -365,7 +368,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, std::move(aot_or.ValueOrDie().back())); compile_result->entry_point = aot_opts.entry_point_name(); compile_result->pointer_size = - xla::LocalClient::PointerSizeForTriple(aot_opts.triple()); + xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple()); return Status::OK(); } @@ -394,8 +397,9 @@ Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, namespace gpu = perftools::gputools; gpu::Platform* cpu_platform = gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie(); - xla::LocalClient* client = - xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie(); + xla::CompileOnlyClient* client = + xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) + .ValueOrDie(); xla::Computation computation; TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation, &compile_result->has_context_arg)); diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 208de5498dbee6773683ac1aa2b33400a8a21f35..5772776666129ed55a479c8917e69df3f3ce2fc0 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -31,6 +31,8 @@ namespace { inline void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); +#elif defined(COMPILER_MSVC) + return _aligned_malloc(size, minimum_alignment); #else // 
!__ANDROID__ && !OS_ANDROID && !OS_CYGWIN void* ptr = nullptr; // posix_memalign requires that the requested alignment be at least @@ -45,7 +47,13 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { free(aligned_memory); } +inline void aligned_free(void* aligned_memory) { +#if defined(COMPILER_MSVC) + _aligned_free(aligned_memory); +#else + free(aligned_memory); +#endif +} size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/aot/tfcompile_util_test.cc index 108ab1eab7bf3b087e8049c5b24d652d871789c8..c321d3ff4c779fbd2e9c67dfc1eb24c734a9103f 100644 --- a/tensorflow/compiler/aot/tfcompile_util_test.cc +++ b/tensorflow/compiler/aot/tfcompile_util_test.cc @@ -24,7 +24,7 @@ namespace tensorflow { namespace tfcompile { namespace { -void ExpectErrorContains(Status status, StringPiece str) { +void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) << "expected error: " << status.error_message() << " to contain: " << str; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index e56f173d518232791d0f490a48bd40e8f14d6cfe..9b4e872ebe561c0d919b1982339896c12bc079f9 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -19,6 +19,7 @@ package( ) load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # Target that bundles up the XLA CPU and GPU JIT devices. 
cc_library( @@ -48,12 +49,12 @@ cc_library( cc_library( name = "xla_gpu_jit", visibility = [":friends"], - deps = [ + deps = if_cuda([ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_local_launch_op", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", - ], + ]), alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index abb68f73d7e3870f733c350be0dc99ab21a6b083..48eed7fce07f0855934600890e157b2752d38838 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -66,9 +66,9 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { int num_constant_args, num_resource_args; TF_RETURN_IF_ERROR( - GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args)); + GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); TF_RETURN_IF_ERROR( - GetNodeAttr(node->def(), kXlaNumResourceArgsAttr, &num_resource_args)); + GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); if (num_constant_args < 0 || num_resource_args < 0 || num_constant_args + num_resource_args > node->num_inputs()) { @@ -88,7 +88,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), - node->def().device(), const_dtypes, num_resource_args, arg_dtypes, + node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, node->output_types(), graph, &launch_node)); launch_node->set_assigned_device_name(node->assigned_device_name()); @@ -173,7 +173,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, FunctionLibraryRuntime::Handle handle; // If ndef is not instantiable, e.g., the function does not exist, // simply bail out. 
- TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle)); + TF_RETURN_IF_ERROR( + flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); const FunctionBody* fbody = flr->GetFunctionBody(handle); CHECK(fbody); // Can't be nullptr since we just instantiated it. std::vector const_args(fbody->arg_types.size()); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 1d2793d3c55f4436a07e4f632887561202d0498e..88ec45f8d86643aa4f7c643ac5bee333fb2ec559 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -88,9 +88,12 @@ class Encapsulator { // Build a FunctionDef for each subgraph, and add it 'library'. The values of // the 'group_attribute' annotations become the function names. + // If 'reuse_existing_functions' is set, use an existing function with the + // same name, if any. // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before // function conversion. Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, FunctionLibraryDefinition* library); // Write a copy of the input graph to 'graph_out', where the subgraphs are @@ -162,7 +165,7 @@ static const char* const kRetValOp = "_Retval"; // none. string Encapsulator::GetFunctionNameAttr(Node const* node) const { string attr; - if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) { + if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) { attr.clear(); } return attr; @@ -192,7 +195,7 @@ Status Encapsulator::SplitIntoSubgraphs() { // Check the device matches any existing device. string device = node->assigned_device_name().empty() - ? node->def().device() + ? 
node->requested_device() : node->assigned_device_name(); if (subgraph.device.empty()) { @@ -236,9 +239,16 @@ Status Encapsulator::SplitIntoSubgraphs() { // Create a new _Retval node DataType dtype = edge->src()->output_type(edge->src_output()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported: tensor ", + edge->src()->name(), ":", edge->src_output()); + } + NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(src_subgraph.graph->NewName("output")); + ret_def.set_name(strings::StrCat(edge->src()->name(), "_", + edge->src_output(), "_retval")); AddNodeAttr("T", dtype, &ret_def); AddNodeAttr("index", ret_index, &ret_def); Node* ret = src_subgraph.graph->AddNode(ret_def, &s); @@ -263,8 +273,16 @@ Status Encapsulator::SplitIntoSubgraphs() { // This is the first time we have seen this tensor. Create an _Arg node. DataType dtype = edge->dst()->input_type(edge->dst_input()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported: tensor ", + edge->src()->name(), ":", edge->src_output()); + } + NodeDef arg_def; - NodeDefBuilder builder(dst_subgraph.graph->NewName("input"), kArgOp); + NodeDefBuilder builder(strings::StrCat(edge->src()->name(), "_", + edge->src_output(), "_arg"), + kArgOp); builder.Attr("T", dtype); builder.Attr("index", arg_index); s = builder.Finalize(&arg_def); @@ -291,11 +309,11 @@ Status Encapsulator::SplitIntoSubgraphs() { } Status Encapsulator::BuildFunctionDefs( - const RewriteSubgraphFn& rewrite_subgraph_fn, + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { // For each subgraph, build a FunctionDef. 
for (auto& subgraph_entry : subgraphs_) { - const string& name = subgraph_entry.first; + string name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; subgraph.call_node_def.set_op(name); @@ -332,6 +350,8 @@ Status Encapsulator::BuildFunctionDefs( for (auto& result : subgraph.results) { result.second = output_permutation[result.second]; } + + name = subgraph.call_node_def.op(); } FunctionDef fdef; @@ -346,7 +366,9 @@ Status Encapsulator::BuildFunctionDefs( strings::StrCat("encapsulate_fdef_", name), fdef); } - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + if (!reuse_existing_functions || library->Find(name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } } return Status::OK(); } @@ -545,14 +567,16 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Status EncapsulateSubgraphsInFunctions( string group_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking, - std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library) { Status s; Encapsulator encapsulator(std::move(group_attribute), &graph_in); s = encapsulator.SplitIntoSubgraphs(); if (!s.ok()) return s; - s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library); + s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, + reuse_existing_functions, library); if (!s.ok()) return s; std::unique_ptr out(new Graph(library)); @@ -569,7 +593,7 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { for (Node* n : graph.nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); if (index < 0 || index >= types->size()) { return errors::InvalidArgument("Invalid argument number"); } @@ -586,7 +610,7 @@ static Status RenumberArguments(Graph* graph, for (Node* n : 
graph->nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); if (index < 0 || index >= permutation.size()) { return errors::InvalidArgument("Invalid argument number"); } @@ -674,7 +698,8 @@ Status EncapsulateSubgraphsPass::Run( TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( kXlaClusterAttr, **options.graph, rewrite_subgraph, - flags->tf_xla_parallel_checking, &graph_out, library)); + flags->tf_xla_parallel_checking, /*reuse_existing_functions=*/false, + &graph_out, library)); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, @@ -688,7 +713,7 @@ Status EncapsulateSubgraphsPass::Run( bool IsXlaCompiledKernel(const Node& node) { bool is_compiled = false; bool has_compilation_attr = - GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() && + GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() && is_compiled; return has_compilation_attr ? is_compiled : false; } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 3ca7dfbf6a0ec29d9517139ffb952298d503cabc..b0987f76c91ed48df52fab303ea6052ebd8fd336 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -34,6 +34,8 @@ namespace tensorflow { // 'input_permutation' and 'output_permutation' are initialized to the identity // permutation. 'nodedef' is the NodeDef for the call to the function under // construction, provided to allow additional attributes to be set. +// The rewrite may also change the NodeDef's operator name, and that +// name will be used as the name of the generated function. 
typedef std::function* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def)> @@ -53,6 +55,9 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via _XlaLaunch operators. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index faab7bd3d25d2491cf74faeb3b06acf4c2d6a054..a8869c8e2a7c164f97917cdae312289efb8b2663 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -76,7 +76,7 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, #define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \ do { \ string diff; \ - EXPECT_TRUE(EqualFunctionDefLibrary(actual, expected, &diff)) \ + EXPECT_TRUE(EqualFunctionDefLibrary(expected, actual, &diff)) \ << diff << "\nActual: " << actual.DebugString(); \ } while (false) @@ -109,7 +109,7 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b, return ops::BinaryOp("BinaryTest", a, b, opts); } -Node* AddNLike(std::vector inputs, +Node* AddNLike(const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest", @@ -144,8 +144,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph, - /* rewrite_subgraph_fn= */ {}, - /* parallel_checking= */ false, + /*rewrite_subgraph_fn=*/{}, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); if (!s.ok()) return s; @@ -205,12 +206,12 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { 
*library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"input__0:float", "input__1:float"}, {"output__2:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {}, { - {{"C"}, "UnaryTest", {"input__0"}}, - {{"c"}, "BinaryTest", {"input__1", "C:o:0"}, {}, {"C"}}, + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, }, - {{"output__2", "c:o:0"}}); + {{"c_0_retval", "c:o:0"}}); { std::unique_ptr lib_def( @@ -261,17 +262,17 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"input__0:float"}, {"output__1:float"}, {}, + "F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {}, { - {{"C"}, "UnaryTest", {"input__0"}}, + {{"C"}, "UnaryTest", {"a_0_arg"}}, }, - {{"output__1", "C:o:0"}}); + {{"c_0_retval", "C:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"input__0:float", "input__1:float"}, {"output__2:float"}, {}, + "F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {}, { - {{"D"}, "BinaryTest", {"input__0", "input__1"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}}, }, - {{"output__2", "D:o:0"}}); + {{"d_0_retval", "D:o:0"}}); { std::unique_ptr lib_def( @@ -340,7 +341,8 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, &graph, &library)); + /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph, + &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; EXPECT_EQ(expected_nodes, GraphNodes(*graph)); @@ -371,7 +373,8 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", 
graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/true, &graph, &library)); + /*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph, + &library)); std::vector expected_nodes = { "add1", "add2", "cluster1", "cluster1_parallel_check/_0", diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc index 88e292a2c1ad8213bc49589a104b38622dee8327..83c23385008d56859b81abee7d292276036a45ee 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -126,8 +126,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, if (node->type_string() == kArgOp) { int index; DataType type; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index)); while (fdef->signature().input_arg_size() <= index) { fdef->mutable_signature()->add_input_arg(); } @@ -143,8 +143,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, if (node->type_string() == kRetValOp) { int index; DataType type; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index)); while (fdef->signature().output_arg_size() <= index) { fdef->mutable_signature()->add_output_arg(); } @@ -161,7 +161,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, } NodeDef* node_def = fdef->add_node_def(); - node_def->CopyFrom(node->def()); + *node_def = node->def(); node_def->set_name(node_names.Uniquify(node->name())); // Reset input names based on graph rather than the NodeDef. 
@@ -203,8 +203,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, // Populate tensor_renaming. NameRangeMap output_ranges; - TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr, - &output_ranges)); + TF_RETURN_IF_ERROR( + NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); for (const auto& output : output_ranges) { for (int i = output.second.first; i < output.second.second; ++i) { const string tensor_name = strings::StrCat( diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc index c741ccfb31efa8794ae745e2e52e3c91b20cfcfc..29c5ff724299ec84d31268c4227259ec02d10742 100644 --- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc @@ -34,7 +34,7 @@ namespace tensorflow { namespace { -Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) { +Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** cache) { XlaDevice::Metadata* metadata; Status s = rm->Lookup(rm->default_container(), "xla_metadata", &metadata); @@ -42,12 +42,8 @@ Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) { return s; } core::ScopedUnref metadata_ref(metadata); - XlaCompiler::Options options; - options.device_type = metadata->jit_device_type(); - options.client = metadata->client(); - options.allow_cpu_custom_calls = false; - options.local_executable_has_hybrid_result = false; - *compiler = new XlaCompilationCache(options); + *cache = + new XlaCompilationCache(metadata->client(), metadata->jit_device_type()); return Status::OK(); } @@ -59,7 +55,7 @@ XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx) OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); function_ = *func; VLOG(1) << "XlaDeviceLaunch created function=" - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), 
AttrSlice(&function_.attr())); DataTypeVector constant_types; OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); num_constant_args_ = constant_types.size(); @@ -85,29 +81,37 @@ std::vector SnapshotResourceVariables(OpKernelContext* ctx, void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaDeviceLaunch::Compute " - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - XlaCompilationCache* compiler; + XlaCompilationCache* cache; OP_REQUIRES_OK(ctx, rm->LookupOrCreate( - rm->default_container(), "xla_compiler", &compiler, - [rm](XlaCompilationCache** compiler) { - return BuildCompilationCache(rm, compiler); + rm->default_container(), "xla_compiler", &cache, + [rm](XlaCompilationCache** cache) { + return BuildCompilationCache(rm, cache); })); // Holds the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but // this is more obviously correct.) 
- core::ScopedUnref compiler_ref(compiler); + core::ScopedUnref cache_ref(cache); std::vector variables = SnapshotResourceVariables(ctx, num_resource_args_); + XlaCompiler::Options options; + options.client = cache->client(); + options.device_type = &cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = false; + options.local_executable_has_hybrid_result = false; + const XlaCompiler::CompilationResult* kernel; - OP_REQUIRES_OK(ctx, compiler->Compile(function_, num_constant_args_, - variables, ctx, &kernel, nullptr)); + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, + variables, ctx, &kernel, nullptr)); VLOG(1) << "XLA compilation complete..."; @@ -117,7 +121,7 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { // Runs the computation, if any. There might not be a computation if all // outputs were compile-time constants. std::vector> outputs; - if (!kernel->computation.IsNull()) { + if (!kernel->computation->IsNull()) { auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); // Builds the inputs to the computation. 
@@ -148,8 +152,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { Env* env = Env::Default(); auto start_time = env->NowMicros(); VLOG(1) << "Executing XLA Computation..."; - auto result = compiler->client()->Execute(kernel->computation, arg_ptrs, - &execution_options, &profile); + auto result = cache->client()->Execute(*kernel->computation, arg_ptrs, + &execution_options, &profile); auto elapsed = env->NowMicros() - start_time; OP_REQUIRES(ctx, result.ok(), result.status()); @@ -158,7 +162,7 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) { auto outputs_or_error = - compiler->client()->DeconstructTuple(*result.ValueOrDie()); + cache->client()->DeconstructTuple(*result.ValueOrDie()); OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status()); outputs = outputs_or_error.ConsumeValueOrDie(); } else { diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc index 8b43c7c1564a340b70e8cfa271a3ef50379b46bc..40acc0d81d08230b373823e333cd5e3e407b9c4f 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc @@ -148,24 +148,28 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) OP_REQUIRES(ctx, num_resource_args == 0, errors::Unimplemented( "XlaLocalLaunchOp does not support resource variables")); -} - -Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) { - gpu::Platform::Id platform_id; if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id = gpu::host::kHostPlatformId; + platform_id_ = gpu::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id = gpu::cuda::kCudaPlatformId; + platform_id_ = gpu::cuda::kCudaPlatformId; } else { - return errors::InvalidArgument("Unknown device type for local _XlaLaunch"); + ctx->SetStatus( + errors::InvalidArgument("Unknown device type for 
local _XlaLaunch")); + return; } +} - auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id); +Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { + auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_); if (!platform.ok()) { return StreamExecutorUtil::ConvertStatus(platform.status()); } - auto client = - xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()); + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); if (!client.ok()) { return client.status(); } @@ -175,18 +179,14 @@ Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) { return errors::InvalidArgument("No JIT device registered for ", device_type_.type()); } - XlaCompiler::Options options; - options.device_type = DeviceType(registration->compilation_device_name); - options.client = client.ValueOrDie(); - options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId); - options.local_executable_has_hybrid_result = true; - *compiler = new XlaCompilationCache(options); + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); return Status::OK(); } void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOp::Compute " - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); @@ -195,23 +195,31 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { gpu::Stream* stream = ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; - XlaCompilationCache* compiler; + XlaCompilationCache* cache; OP_REQUIRES_OK(ctx, rm->LookupOrCreate( - rm->default_container(), "xla_compiler", &compiler, - [this](XlaCompilationCache** compiler) { - return BuildCompilationCache(compiler); + rm->default_container(), "xla_cache", &cache, + [this, ctx](XlaCompilationCache** cache) { + return BuildCompilationCache(ctx, cache); })); // Hold the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but // this is more obviously correct.) - core::ScopedUnref compiler_ref(compiler); + core::ScopedUnref cache_ref(cache); + + xla::LocalClient* client = static_cast(cache->client()); - xla::LocalClient* client = static_cast(compiler->client()); + XlaCompiler::Options options; + options.client = client; + options.device_type = &cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); + options.local_executable_has_hybrid_result = true; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, compiler->Compile(function_, num_constant_args_, {}, ctx, - &kernel, &executable)); + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, {}, + ctx, &kernel, &executable)); VLOG(1) << "Executing XLA Computation..."; @@ -221,7 +229,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { std::unique_ptr output; bool output_is_tuple; - if (!kernel->computation.IsNull()) { + if (!kernel->computation->IsNull()) { // Build xla::ShapedBuffers that point directly to the Tensor buffers. 
std::vector> arg_buffers; arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); @@ -260,8 +268,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(&xla_allocator); - run_options.set_inter_op_thread_pool( - ctx->device()->tensorflow_cpu_worker_threads()->workers); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); Env* env = Env::Default(); auto start_time = env->NowMicros(); diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.h b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h index 8023206762951a4dafba900dd291f2ee9bdbbdf3..5e4d3336a91001fac1d222709f64300e777247c7 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { @@ -43,11 +44,15 @@ class XlaLocalLaunchOp : public OpKernel { private: // Builds a XlaCompilationCache class suitable for the current device. 
- Status BuildCompilationCache(XlaCompilationCache** compiler); + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** compiler); DeviceType device_type_; NameAttrList function_; int num_constant_args_; + + perftools::gputools::Platform::Id platform_id_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index b27c07d0d987aafef1943fd795293bd066ad36f6..73c4e80551485189d1e43fd93eed39083bd6b6b7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -52,20 +52,22 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // Make sure we don't recurse infinitely on recursive functions. const int kMaxRecursionDepth = 10; -bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime); +bool IsCompilableCall(const NodeDef& call_def, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime); -// Tests whether 'while_def' is a completely compilable loop. +// Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. 
-bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_def.op(); +bool IsCompilableWhile(const Node& while_node, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime) { + VLOG(2) << "Loop marking: " << while_node.type_string(); const NameAttrList* name_attr; NodeDef call; Status status; - status = GetNodeAttr(while_def, "cond", &name_attr); + status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { VLOG(2) << "Missing 'cond' attribute on While node."; return false; @@ -78,7 +80,7 @@ bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, VLOG(2) << "Can't compile loop condition: " << cond_func; return false; } - status = GetNodeAttr(while_def, "body", &name_attr); + status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { VLOG(2) << "Missing 'body' attribute on While node."; return false; @@ -98,8 +100,9 @@ bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, // Tests whether 'call_def' is a call to a completely compilable function. // Every operator in the function must be compilable for a function to be // compilable. 
-bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime) { +bool IsCompilableCall(const NodeDef& call_def, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime) { VLOG(2) << "Function marking: " << call_def.op(); if (depth > kMaxRecursionDepth) { @@ -109,7 +112,7 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, FunctionLibraryRuntime::Handle handle; Status status = - lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle); + lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; return false; @@ -131,11 +134,11 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, for (Node* node : fbody->graph->nodes()) { if (node->IsSource() || node->IsSink()) continue; - if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue; - if (node->def().op() == "While") { + if (node->type_string() == "_Arg" || node->type_string() == "_Retval") + continue; + if (node->type_string() == "While") { // Handle functional While loop (not in open source build). 
- return IsCompilableWhile(node->def(), jit_device_type, depth + 1, - lib_runtime); + return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime); } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, depth + 1, @@ -189,17 +192,16 @@ Status FindCompilationCandidates( if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) { VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->def().op(); + << ": " << node->type_string(); continue; } if (!registration->compile_resource_ops && HasResourceArgument(*node)) { VLOG(2) << "Compilation rejected node: resource argument " << node->name() - << ": " << node->def().op(); + << ": " << node->type_string(); continue; } - if (node->def().op() == "While" && - !IsCompilableWhile(node->def(), jit_device_type, 0, - lib_runtime.get())) { + if (node->type_string() == "While" && + !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) { continue; } candidates->insert(node); @@ -316,10 +318,10 @@ Status MarkForCompilationPass::Run( // If there is a _XlaCompile annotation, use its value. bool compile = false; - Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile); + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); if (status.ok()) return compile; - status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile); + status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; // Otherwise use the value of global_jit_level. @@ -482,8 +484,8 @@ Status MarkForCompilationPass::RunImpl( // all nodes marked with _XlaCompile=true to also have a // _XlaScope property set (and raise an error otherwise); but // for now we don't do this. 
- if (GetNodeAttr(node_from->def(), kXlaScopeAttr, &from_scope).ok() && - GetNodeAttr(node_to->def(), kXlaScopeAttr, &to_scope).ok() && + if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && + GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && from_scope != to_scope) { continue; } @@ -538,10 +540,9 @@ Status MarkForCompilationPass::RunImpl( // Compile if the user marked this node _XlaCompile=true bool compile_attr = false; bool marked_for_compilation = false; - if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) { + if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) { marked_for_compilation = compile_attr; - } else if (options.flib_def - ->GetAttr(n->def(), kXlaCompileAttr, &compile_attr) + } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr) .ok()) { marked_for_compilation = compile_attr; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 91e4a2b41c7026b6ca028ed6a7e61588d57e9e50..9f30e12e0e30fef6b4bcd0ea3c091842b008c29a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -57,7 +57,7 @@ std::unordered_map GetClusters(const Graph& graph) { std::unordered_map ids; for (Node* node : graph.nodes()) { string cluster; - if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) { + if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) { CHECK(!cluster.empty()); ids[node->name()] = cluster; } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 41abea02eb2d17423744dfb719ee9a3f6b8f1198..63ca77f9a912acce2078f3da43d64f2e10049380 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -37,9 +37,9 @@ limitations under the License. 
namespace tensorflow { -XlaCompilationCache::XlaCompilationCache(const XlaCompiler::Options& options) - : compiler_(options) {} - +XlaCompilationCache::XlaCompilationCache(xla::Client* client, + DeviceType device_type) + : client_(client), device_type_(std::move(device_type)) {} XlaCompilationCache::~XlaCompilationCache() = default; string XlaCompilationCache::DebugString() { @@ -95,7 +95,7 @@ Status XlaCompilationCache::BuildSignature( const NameAttrList& function, int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, Signature* signature) { - signature->name = Canonicalize(function.name(), function.attr()); + signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); signature->arg_values.resize(num_constant_args); signature->arg_types.reserve(ctx->num_inputs() - num_constant_args); @@ -205,8 +205,9 @@ Status BuildArguments(int num_constant_args, } // namespace Status XlaCompilationCache::Compile( - const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, OpKernelContext* ctx, + const XlaCompiler::Options& options, const NameAttrList& function, + int num_constant_args, const std::vector& variable_args, + OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); @@ -263,21 +264,18 @@ Status XlaCompilationCache::Compile( TF_RETURN_IF_ERROR( BuildArguments(num_constant_args, variable_args, ctx, &args)); - std::unique_ptr flr(NewFunctionLibraryRuntime( - compiler_.device_mgr(), ctx->env(), compiler_.device(), - TF_GRAPH_DEF_VERSION, - ctx->function_library()->GetFunctionLibraryDefinition(), - OptimizerOptions(), nullptr /* custom_kernel_creator */)); - + XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler_.CompileFunction( - flr.get(), function, args, &entry->compilation_result); + entry->compilation_status = + 
compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args, + &entry->compilation_result); } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { if (entry->executable == nullptr && - !entry->compilation_result.computation.IsNull()) { - entry->compilation_status = compiler_.BuildExecutable( + !entry->compilation_result.computation->IsNull()) { + XlaCompiler compiler(options); + entry->compilation_status = compiler.BuildExecutable( entry->compilation_result, &entry->executable); } *executable = entry->executable.get(); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index ff67e48d1a9a9f16881c2e141b23ce8c479aef50..4ffcb68a3220b2354a3542e4c2a4d3e000969e0b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -46,7 +46,7 @@ struct OptionalTensor { // bound. class XlaCompilationCache : public ResourceBase { public: - explicit XlaCompilationCache(const XlaCompiler::Options& options); + XlaCompilationCache(xla::Client* client, DeviceType device_type); ~XlaCompilationCache() override; // Compiles a function into a XlaCompiler::CompilationResult that can be used @@ -61,19 +61,21 @@ class XlaCompilationCache : public ResourceBase { // xla::LocalExecutable and sets `executable to point to it. The resulting // executable pointer may be null if the computation has no non-constant // outputs. 
- Status Compile(const NameAttrList& function, int num_constant_args, + Status Compile(const XlaCompiler::Options& options, + const NameAttrList& function, int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable); - xla::Client* client() const { return compiler_.client(); } + xla::Client* client() const { return client_; } + const DeviceType& device_type() const { return device_type_; } string DebugString() override; private: - XlaCompiler compiler_; - std::unique_ptr function_library_runtime_; + xla::Client* const client_; + const DeviceType device_type_; // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 93f487c36ca5ca8f7e3930cf8f053367400d7920..5e336c5287bd9e2067e93cd8db8a5a1b62b62bd2 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options, const DeviceType& jit_device_name, perftools::gputools::Platform* platform, Allocator* xla_allocator) - : LocalDevice(options, attrs, xla_allocator), + : LocalDevice(options, attrs), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(xla_allocator), diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index a52239df252b2b556987fa9701f43047765c60de..8699006ebc5aacafd46046a7c3f093356f687280 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -63,30 +63,10 @@ class XlaDeviceDummyOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ \ - REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ - ControlTriggerOp); \ - 
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ - REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ - REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ - NextIterationOp); \ - REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ - SwitchOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ - REGISTER_KERNEL_BUILDER(Name("LoopCond") \ - .Device(DEVICE) \ - .HostMemory("input") \ - .HostMemory("output"), \ - IdentityOp); \ - \ REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ ResourceHandleOp); -// TODO(b/32507444): the registrations for the control flow operators are -// temporary and exist primarily to work around a bug in the graph partitioning -// code. - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 0592e3d4b1993d132aa955171c3b523af9869fee..19f7ff835456855a2b2ab7d5856f1d3e6f7f9733 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -65,6 +65,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "adam_test", + size = "small", + srcs = ["adam_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "binary_ops_test", size = "small", @@ -156,6 +170,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "slice_ops_test", + size = "small", + srcs = ["slice_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "function_test", size = "small", diff --git 
a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 0a2c9e26c6fbd827d5ab669dea5419f9fa50025b..a5c5885b4284aee167ae4cb18f7e42820c6d251d 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functional tests for aggregate operations.""" +"""Tests for Adagrad.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3215dc36e5b2d517aa951db1b0d41188185ef93a --- /dev/null +++ b/tensorflow/compiler/tests/adam_test.py @@ -0,0 +1,176 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for Adam.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adam + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamOptimizerTest(XLATestCase): + + def testBasic(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTensorLearningRate(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. 
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + else: + update2.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 7d91594db009f79475afc30ca4a8972b157806ee..2a71543f3febe3cb692fdcd563772c3bd2d3724a 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -94,7 +94,7 @@ class OpTestBuilder { explicit OpTestBuilder(const string& op_name); // Adds an input 'tensor'. 
- OpTestBuilder& Input(Tensor tensor); + OpTestBuilder& Input(const Tensor& tensor); // Sets an attribute. template @@ -111,8 +111,8 @@ class OpTestBuilder { // sets it to the NodeDef of the operator under test. Fills 'inputs' and // 'outputs' with the names of the input placeholder nodes and the output // identity nodes, respectively. - Status BuildGraph(string name_prefix, string device, bool use_jit, - GraphDef* graphdef, NodeDef** test_node_def, + Status BuildGraph(const string& name_prefix, const string& device, + bool use_jit, GraphDef* graphdef, NodeDef** test_node_def, std::vector* inputs, std::vector* outputs) const; @@ -127,7 +127,7 @@ OpTestBuilder::OpTestBuilder(const string& op_name) { node_def_.set_op(op_name); } -OpTestBuilder& OpTestBuilder::Input(Tensor tensor) { +OpTestBuilder& OpTestBuilder::Input(const Tensor& tensor) { VLOG(1) << "Adding input: " << tensor.DebugString(); inputs_.push_back(tensor); return *this; @@ -146,9 +146,9 @@ OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, return *this; } -Status OpTestBuilder::BuildGraph(string name_prefix, string device, - bool use_jit, GraphDef* graphdef, - NodeDef** test_node_def, +Status OpTestBuilder::BuildGraph(const string& name_prefix, + const string& device, bool use_jit, + GraphDef* graphdef, NodeDef** test_node_def, std::vector* inputs, std::vector* outputs) const { OpRegistryInterface* op_registry = OpRegistry::Global(); @@ -209,7 +209,7 @@ class OpTest : public ::testing::Test { // Runs 'fn' up to --tf_xla_test_repetitions times, or until a failure occurs; // whichever happens first. - void Repeatedly(std::function fn); + void Repeatedly(const std::function& fn); // Select a random element from 'candidates'. 
template @@ -315,7 +315,7 @@ OpTest::OpTest() { TF_CHECK_OK(session_->Create(def)); } -void OpTest::Repeatedly(std::function fn) { +void OpTest::Repeatedly(const std::function& fn) { int const max_repetitions = tf_xla_test_repetitions; for (int i = 0; !HasFailure() && i < max_repetitions; ++i) { fn(); diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..795885f8302dbf41ef04e37b87abdd0d4bf12727 --- /dev/null +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for slicing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + +class SliceTest(XLATestCase): + + def test1D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = array_ops.slice(i, [2], [4]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([2, 3, 4, 5], result) + + def test3D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + with self.test_scope(): + o = array_ops.slice(i, [1, 2, 2], [1, 1, 4]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[6, 5, 4, 3]]], result) + +class StridedSliceTest(XLATestCase): + + def test1D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [2], [6], [2]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([2, 4], result) + + def test1DNegtiveStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = 
array_ops.strided_slice(i, [6], [2], [-2]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([6, 4], result) + + def test3D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[1, 9]], + [[6, 4]]], result) + + def test3DNegativeStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 4, 10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0], + [4, 5, 2, 4, 3, 7, 6, 8, 9, 4]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [4, 3, 4, 5, 7, 6, 5, 3, 4, 5], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7], + [7, 1, 7, 1, 8, 1, 8, 1, 3, 1]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9], + [9, 9, 5, 5, 6, 6, 3, 3, 6, 6]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[9, 8], + [1, 1]], + [[2, 4], + [5, 7]]], result) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index dcb9e2db2f8ca7ef6e89cb9c6493d15dcaacd46e..fef390fd67f38bc2b1a26cb2e80ffa4ca834d98d 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -26,6 +26,7 
@@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -36,6 +37,21 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer class VariableOpsTest(XLATestCase): """Test cases for resource variable operators.""" + def testOneWriteOneOutput(self): + # Regression test for a bug where computations with one non-constant + # output and one variable update were mishandled. + for dtype in self.numeric_types: + init = np.array([[1, 2], [3, 4]], dtype=dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + x = v.assign_add(p) + with ops.control_dependencies([x]): + y = v.read_value() + self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype), + sess.run(y, {p: 1})) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" with self.test_session() as session: diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 1388a892ba5a1d07c05eedf277085099923ae901..f5c228f8305d740b994dadc34c93b4e0ae32d785 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -18,15 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from 
tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -48,34 +43,6 @@ class XlaDeviceTest(test.TestCase): result = sess.run(w, {x: [1.5, 0.5]}) self.assertAllClose(result, [12., 2.], rtol=1e-3) - def testLoops(self): - """Tests that loops work on XLA devices.""" - - with session_lib.Session() as session: - x = array_ops.placeholder(dtypes.float32) - with ops.device("device:XLA_CPU:0"): - c = lambda i, _: math_ops.less(i, 5) - b = lambda i, x: (i + 1, x * 2.0 + 1.0) - _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x)) - - result = session.run(y, {x: np.float32(2)}) - self.assertAllClose(result, np.float32(95), rtol=1e-3) - - def testCond(self): - """Tests that tf.cond works on XLA devices.""" - - with session_lib.Session() as session: - x = array_ops.placeholder(dtypes.float32) - y = array_ops.placeholder(dtypes.float32) - c = array_ops.placeholder(dtypes.bool) - with ops.device("device:XLA_CPU:0"): - z = x + 1.0 - w = control_flow_ops.cond(c, lambda: z, lambda: y) - t = math_ops.add(z, w) - - result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True}) - self.assertAllClose(result, np.float32(6), rtol=1e-3) - if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 44ff13ca34e740b12f28d4952ab968472e5d1e57..4adc17b8382bd423264a693a09e2cec0803ad9cf 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -108,7 +108,7 @@ Status BackwardsConstAnalysis(const Graph& g, if (must_be_const.find(node) != must_be_const.end()) { if (node->type_string() == "_Arg") { int index; - status = GetNodeAttr(node->def(), "index", &index); + status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; compile_time_const_args->at(index) = true; return; @@ -124,8 +124,8 @@ Status BackwardsConstAnalysis(const 
Graph& g, if (range.first == range.second) return; NameRangeMap input_name_ranges; - status = NameRangesForNode(node->def(), node->op_def(), &input_name_ranges, - nullptr); + status = + NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); if (!status.ok()) return; for (auto it = range.first; it != range.second; ++it) { diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index d718f98545f66cb79a77d758a3fb7ee486d87b4b..8dacb6627bde516c92cb07b747207adbe85ada5b 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -68,7 +68,8 @@ class SymbolicGradientOp : public AsyncOpKernel { done); OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done); + ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_), + done); FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc index eff23bd77d23afc882c67f8168270d1cb4413977..ef844cc6c5ae07a3e6331971023a280ee0cafe41 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,7 @@ EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. 
-extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT gather_float_int32_xla_impl(float* out, void** data) { tensorflow::gather_float_int32_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc index ae31f6f2006959c03941a1eb04b31aecf52424b0..4c8693d1976bf0817a01c2bacbbf4708202ce51e 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,7 @@ EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT gather_float_int64_xla_impl(float* out, void** data) { tensorflow::gather_float_int64_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 0033a949a372684caadce70bf46a996a942e9ec4..a71f2fcf0f7755d4e9ed2a9fd8b50a2e07bcfd2f 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -43,7 +44,7 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index be8ad2317c9ba6a39f839c4a535440fb94365aa9..f30eb6121fc858c50b9c00255e86105fe8ebcc54 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -45,7 +46,7 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. 
-extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 03e02299e33a4e2bf62e757b2092db35288b0bea..bbe157bbeac56a396d946685c164867194accb42 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -77,11 +77,9 @@ class StridedSliceOp : public XlaOpKernel { gtl::InlinedVector dimensions_to_reverse; gtl::InlinedVector slice_begin, slice_end; + bool simple_strides = true; for (int i = 0; i < begin.size(); ++i) { - // TODO(phawkins): implement strides != 1 when b/30878775 is fixed. - OP_REQUIRES( - ctx, strides[i] == 1 || strides[i] == -1, - errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); + simple_strides &= (std::abs(strides[i]) == 1); if (strides[i] > 0) { slice_begin.push_back(begin[i]); slice_end.push_back(end[i]); @@ -99,6 +97,36 @@ class StridedSliceOp : public XlaOpKernel { slice = ctx->builder()->Rev(slice, dimensions_to_reverse); } + // If at least one of the strides is > 1 (or < -1) then use Slice + // to pull out each of the strided slices, and Concat to put them + // together again. + if (!simple_strides) { + + // Re-adjust the begin and end now that the periphery has been + // sliced away. 
+ for (int d = 0; d < strides.size(); ++d) { + slice_end[d] -= slice_begin[d]; + slice_begin[d] = 0; + } + + for (int d = 0; d < strides.size(); ++d) { + int64 stride = std::abs(strides[d]); + if (stride > 1) { + std::vector to_concat; + int64 end = slice_end[d]; + for (int64 i = 0; i < end; i += stride) { + slice_begin[d] = i; + slice_end[d] = i+1; + to_concat.push_back(ctx->builder()->Slice(slice, slice_begin, + slice_end)); + } + slice = ctx->builder()->ConcatInDim(to_concat, d); + slice_begin[d] = 0; + slice_end[d] = to_concat.size(); + } + } + } + slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index f1d81f871423b220c6859c1dedf79b1c36a43e65..ddd81cb490cd76065735a5b7e78d04fd76c05f82 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -165,6 +165,106 @@ class ResourceApplyAdagrad : public XlaOpKernel { }; REGISTER_XLA_OP(Name("ResourceApplyAdagrad"), ResourceApplyAdagrad); +class ResourceApplyAdam : public XlaOpKernel { + public: + explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType var_type, m_type, v_type; + TensorShape var_shape, m_shape, v_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape)); + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape)); + + OP_REQUIRES( + ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyRMSProp must match: ", + DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ", + DataTypeString(m_type), " vs. 
", DataTypeString(v_type))); + + TensorShape beta1_power_shape = ctx->InputShape(3); + TensorShape beta2_power_shape = ctx->InputShape(4); + TensorShape lr_shape = ctx->InputShape(5); + TensorShape beta1_shape = ctx->InputShape(6); + TensorShape beta2_shape = ctx->InputShape(7); + TensorShape epsilon_shape = ctx->InputShape(8); + TensorShape grad_shape = ctx->InputShape(9); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape), + errors::InvalidArgument("beta1_power is not a scalar: ", + beta1_power_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_shape), + errors::InvalidArgument("beta2_power is not a scalar: ", + beta2_power_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar : ", + lr_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape), + errors::InvalidArgument("beta1 is not a scalar: ", + beta1_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape), + errors::InvalidArgument("beta2 is not a scalar: ", + beta2_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape), + errors::InvalidArgument("var and m do not have the same shape", + var_shape.DebugString(), " ", + m_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape), + errors::InvalidArgument("var and v do not have the same shape", + var_shape.DebugString(), " ", + v_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::ComputationDataHandle var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m)); + OP_REQUIRES_OK(ctx, 
ctx->ReadVariableInput(2, &v)); + xla::ComputationDataHandle beta1_power = ctx->Input(3); + xla::ComputationDataHandle beta2_power = ctx->Input(4); + xla::ComputationDataHandle lr = ctx->Input(5); + xla::ComputationDataHandle beta1 = ctx->Input(6); + xla::ComputationDataHandle beta2 = ctx->Input(7); + xla::ComputationDataHandle epsilon = ctx->Input(8); + xla::ComputationDataHandle grad = ctx->Input(9); + + // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) + // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t + // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t + // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + + xla::ComputationDataHandle alpha = + b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), + b->Sub(one, beta1_power)); + m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); + v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2))); + var = + b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon))); + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdam"), ResourceApplyAdam); + class ResourceApplyRMSProp : public XlaOpKernel { public: explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc index ce25d631271b54a36078cd0d3ac4d318d58db9fa..2b0834fe7b6c4d2199267dbe0ec1f7c2785aa9c7 100644 --- a/tensorflow/compiler/tf2xla/str_util.cc +++ b/tensorflow/compiler/tf2xla/str_util.cc @@ -22,7 +22,7 @@ limitations 
under the License. namespace tensorflow { namespace str_util { -void ReplaceAll(string* text, StringPiece from, StringPiece to) { +static void ReplaceAll(string* text, StringPiece from, StringPiece to) { size_t pos = 0; while ((pos = text->find(from.data(), pos, from.size())) != string::npos) { text->replace(pos, from.size(), to.data(), to.size()); diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h index 4920b1a4d4875192d6f06988b810ad388bc6293b..51f25009d7003db0d72296619a469ecbbbb1808d 100644 --- a/tensorflow/compiler/tf2xla/str_util.h +++ b/tensorflow/compiler/tf2xla/str_util.h @@ -29,10 +29,6 @@ limitations under the License. namespace tensorflow { namespace str_util { -// Replace all non-overlapping occurrences of from with to in-place in text. If -// from is empty, it matches at the beginning of the text and after every byte. -void ReplaceAll(string* text, StringPiece from, StringPiece to); - // Replace all non-overlapping occurrences of the given (from,to) pairs in-place // in text. If from is empty, it matches at the beginning of the text and after // every byte. Each (from,to) replacement pair is processed in the order it is diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc index f992007a34532157f86c90c717a5e24c3923f22d..8817f6902a8e58e796ca5240a9a24d7506d38793 100644 --- a/tensorflow/compiler/tf2xla/str_util_test.cc +++ b/tensorflow/compiler/tf2xla/str_util_test.cc @@ -25,36 +25,6 @@ limitations under the License. 
namespace tensorflow { namespace str_util { -class ReplaceAllTest : public ::testing::Test { - protected: - void ExpectReplaceAll(string text, StringPiece from, StringPiece to, - StringPiece want) { - ReplaceAll(&text, from, to); - EXPECT_EQ(text, want); - } -}; - -TEST_F(ReplaceAllTest, Simple) { - ExpectReplaceAll("", "", "", ""); - ExpectReplaceAll("", "", "X", "X"); - ExpectReplaceAll("", "", "XYZ", "XYZ"); - ExpectReplaceAll("banana", "", "", "banana"); - ExpectReplaceAll("banana", "", "_", "_b_a_n_a_n_a_"); - ExpectReplaceAll("banana", "", "__", "__b__a__n__a__n__a__"); - ExpectReplaceAll("banana", "a", "a", "banana"); - ExpectReplaceAll("banana", "a", "", "bnn"); - ExpectReplaceAll("banana", "a", "X", "bXnXnX"); - ExpectReplaceAll("banana", "a", "XX", "bXXnXXnXX"); - ExpectReplaceAll("banana", "an", "an", "banana"); - ExpectReplaceAll("banana", "an", "", "ba"); - ExpectReplaceAll("banana", "an", "X", "bXXa"); - ExpectReplaceAll("banana", "an", "XY", "bXYXYa"); - ExpectReplaceAll("banana", "an", "XYZ", "bXYZXYZa"); - ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "X", "foo X baz X"); - ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "ABCDEFGHIJKLMNOP", - "foo ABCDEFGHIJKLMNOP baz ABCDEFGHIJKLMNOP"); -} - class ReplaceAllPairsTest : public ::testing::Test { protected: void ExpectReplaceAllPairs( diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d86e741b69e08652bac2dd7b5295c8ab2d94433a..362a1018955f9b6adbdea5ba718b81e9a2389957 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, options, Device::BuildDeviceAttributes( "", type, Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type())), - cpu_allocator()), + strings::StrCat("device: XLA compilation device ", 
type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 33b4a43aa1544f883d4242148ce77eebb8a4c54c..d4a917671b9cb9031e04d6840625b034720934c7 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -57,16 +57,37 @@ Status CheckSignature(const DataTypeVector& types, } // namespace +bool XlaCompiler::Argument::operator==( + const XlaCompiler::Argument& other) const { + if (std::tie(kind, type, shape, name) != + std::tie(other.kind, other.type, other.shape, other.name)) { + return false; + } + if (constant_value.shape() != other.constant_value.shape()) { + return false; + } + return constant_value.tensor_data() == other.constant_value.tensor_data(); +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(std::move(options)), initialization_status_(Status::OK()), next_step_id_(1), - device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), + device_( + new XlaCompilationDevice(SessionOptions(), *options_.device_type)), device_mgr_({device_}) { + // We no longer need the device_type. + options_.device_type = nullptr; + if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); } + + flib_runtime_.reset(NewFunctionLibraryRuntime( + &device_mgr_, Env::Default(), device_, options.graph_def_version, + options.flib_def, OptimizerOptions(), + nullptr /* custom_kernel_creator */)); } XlaCompiler::~XlaCompiler() = default; @@ -76,37 +97,35 @@ int64 XlaCompiler::NextStepId() { return next_step_id_++; } -// Prunes any nodes from a function that are not dependencies of the _Retval -// nodes. Used to prune stateful ops from within a function body, such as -// variable initializers, that should not be executed unless requested. 
-static void PruneUnreachableNodes(Graph* graph) { - std::unordered_set nodes; - for (Node* node : graph->nodes()) { - if (node->type_string() == "_Retval" || - StringPiece(node->type_string()).ends_with("Send")) { - nodes.insert(node); - } - } - PruneForReverseReachability(graph, nodes); +uint64 XlaCompiler::SignatureHash::operator()( + const std::pair>& signature) const { + return std::hash()(signature.first); } Status XlaCompiler::CompileFunction( - FunctionLibraryRuntime* flr, const NameAttrList& function, + const XlaCompiler::CompileOptions& options, const NameAttrList& function, const std::vector& args, XlaCompiler::CompilationResult* result) { - const string function_id = Canonicalize(function.name(), function.attr()); + const string function_id = + Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; + auto it = cache_.find({function_id, args}); + if (it != cache_.end()) { + *result = it->second; + return Status::OK(); + } + FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR( - flr->Instantiate(function.name(), function.attr(), &handle)); + TF_RETURN_IF_ERROR(flib_runtime_->Instantiate( + function.name(), AttrSlice(&function.attr()), &handle)); - const FunctionBody* fbody = flr->GetFunctionBody(handle); + const FunctionBody* fbody = flib_runtime_->GetFunctionBody(handle); CHECK(fbody); TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args)); - std::unique_ptr graph(new Graph(flr->GetFunctionLibraryDefinition())); + std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); if (VLOG_IS_ON(1)) { @@ -115,11 +134,13 @@ Status XlaCompiler::CompileFunction( } // Optimize the graph before running the compiler. - // TODO(pbar): The constant folder currently does not simplify int32 - // operations for devices other than CPU. 
OptimizerOptions opts; + opts.set_do_common_subexpression_elimination(true); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); - OptimizeGraph(flr, &graph); + optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(), + /*device=*/nullptr, &graph); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile( @@ -129,9 +150,10 @@ Status XlaCompiler::CompileFunction( VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( - CompileGraph(function_id, std::move(graph), flr, args, result)); + CompileGraph(options, function_id, std::move(graph), args, result)); VLOG(1) << "===================================================="; + cache_[{function_id, args}] = *result; return Status::OK(); } @@ -158,7 +180,7 @@ Status XlaCompiler::BuildExecutable( build_options.set_has_hybrid_result( options_.local_executable_has_hybrid_result); - auto compile_result = local_client->Compile(result.computation, + auto compile_result = local_client->Compile(*result.computation, argument_layouts, build_options); if (!compile_result.ok()) { return compile_result.status(); @@ -378,9 +400,9 @@ Status BuildComputation( } // namespace -Status XlaCompiler::CompileGraph(string const& name, +Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, + string const& name, std::unique_ptr graph, - FunctionLibraryRuntime* flib, const std::vector& args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; @@ -394,31 +416,29 @@ Status XlaCompiler::CompileGraph(string const& name, options_.resolve_compile_time_constants); core::ScopedUnref context_unref(context); - result->tuple_arg = options_.use_tuple_arg; + result->tuple_arg = options.use_tuple_arg; std::vector context_args; - TF_RETURN_IF_ERROR(BuildArguments(args, options_.use_tuple_arg, &builder, + TF_RETURN_IF_ERROR(BuildArguments(args, options.use_tuple_arg, &builder, &context_args, 
&result->input_mapping, &result->xla_input_shapes)); context->set_args(std::move(context_args)); - if (options_.prune_unreachable_nodes) { - PruneUnreachableNodes(graph.get()); - } - - TF_RETURN_IF_ERROR( - ExecuteGraph(context, std::move(graph), device_, flib, NextStepId())); + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, + flib_runtime_.get(), NextStepId())); int num_nonconst_outputs; + result->computation = std::make_shared(); TF_RETURN_IF_ERROR(BuildComputation( context->retvals(), context->variables(), context->has_side_effects(), - options_.return_updated_values_for_all_variables, &builder, - &result->computation, &num_nonconst_outputs, &result->variable_updates)); + options.return_updated_values_for_all_variables, &builder, + result->computation.get(), &num_nonconst_outputs, + &result->variable_updates)); result->requires_runtime_context = context->has_context_parameter(); // Tuple arguments and runtime context parameters are incompatible. - CHECK(!(options_.use_tuple_arg && result->requires_runtime_context)); + CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; @@ -434,13 +454,13 @@ Status XlaCompiler::CompileGraph(string const& name, } } - if (result->computation.IsNull()) { + if (result->computation->IsNull()) { return Status::OK(); } // Compute the output shapes, if there is a computation with non-constant // outputs. 
- auto computation_shape = client()->GetComputationShape(result->computation); + auto computation_shape = client()->GetComputationShape(*result->computation); if (!computation_shape.ok()) { return computation_shape.status(); } @@ -472,10 +492,10 @@ Status XlaCompiler::CompileGraph(string const& name, i < context->retvals().size(); ++i) { const XlaContext::HandleOrConstant& retval = context->retvals()[i]; if (!retval.is_constant) { - CHECK_LT(computation_output, num_nonconst_outputs); + CHECK_LT(computation_output, num_computation_outputs); OutputDescription& output = result->outputs[i]; output.is_constant = false; - if (num_nonconst_outputs > 1) { + if (num_computation_outputs > 1) { output.shape = XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( result->xla_output_shape, computation_output)); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 3d28ca374609df28647d243544dcbf8cbf33e706..15f723ad782376b99ae7d72a5f15129e7880e9b1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -112,6 +113,8 @@ class XlaCompiler { // The name of this argument, used for debugging. string name; + + bool operator==(const Argument& other) const; }; struct OutputDescription { @@ -172,15 +175,22 @@ class XlaCompiler { // The XLA computation built from the tensorflow subgraph. May be null // if the output consists solely of compile-time constants. - xla::Computation computation; + std::shared_ptr computation; }; struct Options { - // Name of the compilation device to use. 
- DeviceType device_type = DeviceType(""); + // Name of the compilation device to use. Needs to be live only during + // XlaCompiler's constructor. + const DeviceType* device_type = nullptr; xla::Client* client = nullptr; + // Function library in which to find function definitions. Must be non-null. + const FunctionLibraryDefinition* flib_def = nullptr; + + // The graph def version to be compiled. + int graph_def_version = TF_GRAPH_DEF_VERSION; + // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall() // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed // to the computation. @@ -198,6 +208,19 @@ class XlaCompiler { // computation. bool resolve_compile_time_constants = true; + // If not nullptr, populate_resource_manager is called with the + // compilation device's resource manager when the compilation + // device is created, and can be used to create metadata objects + // that can be accessed by XLA op kernels. + std::function* populate_resource_manager = nullptr; + }; + + explicit XlaCompiler(Options options); + ~XlaCompiler(); + + // Options pertaining to an individual call to CompileGraph() or + // CompileFunction(). + struct CompileOptions { // If `use_tuple_arg` is true, a single tuple parameter will be used for all // arguments; if false, each argument gets its own parameter. bool use_tuple_arg = false; @@ -208,23 +231,8 @@ class XlaCompiler { // modified by the computation. Used when compiling loop bodies to ensure // the input and output signatures match. bool return_updated_values_for_all_variables = false; - - // If 'prune_unreachable_nodes' is true, then nodes that are not - // dependencies of graph's _Retval nodes will be pruned before compilation. - // This is useful to prune stateful operators that should not be executed - // from a function body. 
- bool prune_unreachable_nodes = false; - - // If not nullptr, populate_resource_manager is called with the - // compilation device's resource manager when the compilation - // device is created, and can be used to create metadata objects - // that can be accessed by XLA op kernels. - std::function* populate_resource_manager = nullptr; }; - explicit XlaCompiler(Options options); - ~XlaCompiler(); - // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. // `args` describes the arguments to the function, each of which must either // be a runtime-parameter to the XLA computation, a compile-time constant, or @@ -235,7 +243,7 @@ class XlaCompiler { // arguments are returned as host memory tensors in the output list and are // not included in the XLA computation's outputs. The XLA computation is // null if there are no data-dependent outputs and no side effects. - Status CompileFunction(FunctionLibraryRuntime* flr, + Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, const std::vector& args, CompilationResult* result); @@ -243,8 +251,8 @@ class XlaCompiler { // Compiles a tensorflow::Graph into an xla::Computation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. - Status CompileGraph(string const& name, std::unique_ptr graph, - FunctionLibraryRuntime* flr, + Status CompileGraph(const CompileOptions& options, string const& name, + std::unique_ptr graph, const std::vector& args, CompilationResult* result); @@ -257,6 +265,7 @@ class XlaCompiler { xla::Client* client() const { return options_.client; } XlaCompilationDevice* device() const { return device_; } const DeviceMgr* device_mgr() const { return &device_mgr_; } + FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_.get(); } // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. 
@@ -281,6 +290,17 @@ class XlaCompiler { XlaCompilationDevice* device_; // Owned by device_mgr_ DeviceMgr device_mgr_; + std::unique_ptr flib_runtime_; + + struct SignatureHash { + uint64 operator()( + const std::pair>& signature) const; + }; + + std::unordered_map>, + CompilationResult, SignatureHash> + cache_; + std::unordered_map channels_ GUARDED_BY(mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 1cc7f4abd15798b29fe065c65c618b0166007b7e..58d74057d101cdef89fca24ec6c0858291d825fa 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -96,6 +96,8 @@ REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp); class XlaCompilerTest : public ::testing::Test { protected: + XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} + void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -107,19 +109,13 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + options.device_type = &cpu_device_type_; options.client = client_; + options.flib_def = flib_def_.get(); return options; } - std::unique_ptr BuildFunctionLibraryRuntime( - const XlaCompiler& compiler) { - return std::unique_ptr(NewFunctionLibraryRuntime( - compiler.device_mgr(), /*env=*/nullptr, compiler.device(), - TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), - /*custom_kernel_creator=*/nullptr)); - } - + DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -127,15 +123,15 @@ class XlaCompilerTest : public ::testing::Test { // Tests compilation of an empty graph. 
TEST_F(XlaCompilerTest, EmptyReturnValues) { XlaCompiler compiler(DefaultOptions()); - auto flr = BuildFunctionLibraryRuntime(compiler); std::unique_ptr graph(new Graph(OpRegistry::Global())); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph("add", std::move(graph), flr.get(), + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), /*args=*/{}, &result)); // No computation should be generated. - EXPECT_EQ(0, result.computation.handle().handle()); + EXPECT_EQ(0, result.computation->handle().handle()); } // Tests compilation and execution of a graph that adds two tensors. @@ -160,11 +156,10 @@ TEST_F(XlaCompilerTest, Simple) { // Compiles the graph. XlaCompiler compiler(DefaultOptions()); - auto flr = BuildFunctionLibraryRuntime(compiler); XlaCompiler::CompilationResult result; - TF_ASSERT_OK( - compiler.CompileGraph("add", std::move(graph), flr.get(), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); // Tests that the generated computation works. 
std::unique_ptr param0_literal = @@ -178,7 +173,7 @@ TEST_F(XlaCompilerTest, Simple) { std::unique_ptr actual = client_ - ->Execute(result.computation, {param0_data.get(), param1_data.get()}) + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); @@ -213,14 +208,14 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { XlaCompiler::Options options = DefaultOptions(); options.resolve_compile_time_constants = true; XlaCompiler compiler(options); - auto flr = BuildFunctionLibraryRuntime(compiler); std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph("constants", std::move(graph_copy), - flr.get(), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), + "constants", std::move(graph_copy), args, + &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_TRUE(result.outputs[0].is_constant); @@ -235,7 +230,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = - client_->Execute(result.computation, {param0_data.get()}) + client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); @@ -250,14 +245,14 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { XlaCompiler::Options options = DefaultOptions(); options.resolve_compile_time_constants = false; XlaCompiler compiler(options); - auto flr = BuildFunctionLibraryRuntime(compiler); std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph("constants", std::move(graph_copy), - flr.get(), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), + 
"constants", std::move(graph_copy), args, + &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_FALSE(result.outputs[0].is_constant); @@ -270,7 +265,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = - client_->Execute(result.computation, {param0_data.get()}) + client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); @@ -312,13 +307,12 @@ TEST_F(XlaCompilerTest, ResourceManager) { }; options.populate_resource_manager = &populate_function; XlaCompiler compiler(options); - auto flr = BuildFunctionLibraryRuntime(compiler); EXPECT_EQ(0, resource->Get()); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph("dummy", std::move(graph), flr.get(), args, - &result)); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", + std::move(graph), args, &result)); EXPECT_EQ(1, resource->Get()); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 10d8b67bbd2d0e897e3ca55e584f575448a3a4fd..f8589edafc401bb511774ae3fede67f121efbcd7 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -89,7 +90,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::F16: - LOG(FATAL) << "f16 literals not yet implemented"; + literal = *xla::LiteralUtil::CreateR0( + static_cast(value)); + break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; case xla::OPAQUE: @@ -107,6 +110,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { + case xla::F16: + return b->ConstantR0(static_cast(value)); + break; case xla::F32: return b->ConstantR0(static_cast(value)); break; diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h index cd773d64ed4154aa2a05ac2d15e9358614239b1f..dca420d6ee3fec45f88ac3b450ab0cb4fb83d38a 100644 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h @@ -23,7 +23,7 @@ limitations under the License. // actually used. E.g. some ahead-of-time compiled computations don't need a // thread pool. 
namespace Eigen { -class ThreadPoolDevice; +struct ThreadPoolDevice; } namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index a022de36a26d5f85e11b11ccd8dba4760aa8552f..48831ce4c27dfc644e8cd821e04cce3639ec0af5 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -379,8 +379,8 @@ void XlaOpKernelContext::SetOpHasSideEffects() { XlaContext::Get(context_).AddSideEffects(); } -const XlaCompiler::Options& XlaOpKernelContext::GetCompilerOptions() const { - return XlaContext::Get(context_).compiler()->options(); +XlaCompiler* XlaOpKernelContext::compiler() const { + return XlaContext::Get(context_).compiler(); } void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index f97e07bea5d13a8b5c65cbd378aba5a2a76d70d9..0a8a9284186e5b72a8a376ad159eb7b2482699c5 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -186,10 +186,9 @@ class XlaOpKernelContext { // Returns the underlying OpKernelContext. Use rarely. OpKernelContext* op_kernel_context() const { return context_; } - // Returns the options passed to the XlaCompiler that is being - // run. Used for, e.g., While to inherit options needed for nested - // computation. - const XlaCompiler::Options& GetCompilerOptions() const; + // Returns the XlaCompiler that is performing the compilation. Used for, e.g., + // While to compile nested computations. + XlaCompiler* compiler() const; // TODO(phawkins): find a better home for these helpers. 
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 7576dff0cd701e06a46ee5a809f376c455fe391e..de09d4b23f8d8b140bbb37f32d651f3cede897ec 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -216,7 +216,7 @@ cc_test( ":test_helpers", ":types", ":util", - "//tensorflow/core:test", + ":xla_data_proto", "//tensorflow/core:test_main", ], ) @@ -256,6 +256,7 @@ cc_library( ":array3d", ":array4d", ":shape_util", + ":status_macros", ":types", ":util", ":xla_data_proto", @@ -274,6 +275,7 @@ cc_test( ":test", ":types", "//tensorflow/core:lib", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 3e9dfe2a922c913c528d586413c11e2da8cbdc39..2d96128e259da316a41e83bea221ae201ad88a13 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -99,6 +99,26 @@ cc_library( ], ) +cc_library( + name = "compile_only_client", + srcs = ["compile_only_client.cc"], + hdrs = ["compile_only_client.h"], + deps = [ + ":client", + ":computation", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:compile_only_service", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:support", + ], +) + # This target is used to instantiate the XLA service in-process and create # a client for it. 
cc_library( @@ -106,12 +126,14 @@ cc_library( srcs = ["client_library.cc"], hdrs = ["client_library.h"], deps = [ + ":compile_only_client", ":local_client", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 93437023bc8956e449f828f5bf6dea7a6bff8610..8238261e1c90cadeda9005e437d684d3770bd67b 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -43,6 +43,16 @@ int LocalClientOptions::number_of_replicas() const { return number_of_replicas_; } +LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int LocalClientOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + /* static */ ClientLibrary& ClientLibrary::Singleton() { static ClientLibrary* c = new ClientLibrary; return *c; @@ -69,22 +79,24 @@ ClientLibrary::~ClientLibrary() = default; TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - auto it = client_library.instances_.find(platform->id()); - if (it != client_library.instances_.end()) { + auto it = client_library.local_instances_.find(platform->id()); + if (it != client_library.local_instances_.end()) { return it->second->client.get(); } ServiceOptions service_options; service_options.set_platform(platform); service_options.set_number_of_replicas(replica_count); + service_options.set_intra_op_parallelism_threads( + options.intra_op_parallelism_threads()); - 
std::unique_ptr instance = MakeUnique(); + auto instance = MakeUnique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); instance->client = MakeUnique(instance->service.get()); LocalClient* cl = instance->client.get(); - client_library.instances_.insert( + client_library.local_instances_.insert( std::make_pair(platform->id(), std::move(instance))); return cl; } @@ -99,9 +111,35 @@ ClientLibrary::~ClientLibrary() = default; perftools::gputools::Platform* platform) { ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); - auto it = client_library.instances_.find(platform->id()); - CHECK(it != client_library.instances_.end()); + auto it = client_library.local_instances_.find(platform->id()); + CHECK(it != client_library.local_instances_.end()); return it->second->service.get(); } +/* static */ StatusOr +ClientLibrary::GetOrCreateCompileOnlyClient( + perftools::gputools::Platform* platform) { + ClientLibrary& client_library = Singleton(); + tensorflow::mutex_lock lock(client_library.service_mutex_); + + if (platform == nullptr) { + TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + } + + auto it = client_library.compile_only_instances_.find(platform->id()); + if (it != client_library.compile_only_instances_.end()) { + return it->second->client.get(); + } + + auto instance = MakeUnique(); + TF_ASSIGN_OR_RETURN(instance->service, + CompileOnlyService::NewService(platform)); + instance->client = MakeUnique(instance->service.get()); + CompileOnlyClient* cl = instance->client.get(); + + client_library.compile_only_instances_.insert( + std::make_pair(platform->id(), std::move(instance))); + return cl; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 2bc319f9333368635690add017ad3d89947e2551..3ddd235d0efeeb78f49eafbf670d7c74a88960dd 100644 --- 
a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -26,7 +26,9 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -51,9 +53,14 @@ class LocalClientOptions { LocalClientOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; + // Sets the thread pool size for parallel execution of an individual operator. + LocalClientOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + private: perftools::gputools::Platform* platform_ = nullptr; int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; }; class ClientLibrary { @@ -76,6 +83,13 @@ class ClientLibrary { // access user computations from client. static LocalService* GetXlaService(perftools::gputools::Platform* platform); + // Singleton constructor-or-accessor for compile-only clients. Arguments: + // + // platform : The platform the underlying XLA service should target. If + // null then default platform is used. + static StatusOr GetOrCreateCompileOnlyClient( + perftools::gputools::Platform* platform = nullptr); + private: // Returns the singleton instance of ClientLibrary. static ClientLibrary& Singleton(); @@ -90,10 +104,21 @@ class ClientLibrary { std::unique_ptr client; }; + struct CompileOnlyInstance { + // Service that is wrapped by the singleton client object. + std::unique_ptr service; + // Singleton client object. + std::unique_ptr client; + }; + tensorflow::mutex service_mutex_; // Guards the singleton creation state. 
std::unordered_map> - instances_ GUARDED_BY(service_mutex_); + local_instances_ GUARDED_BY(service_mutex_); + + std::unordered_map> + compile_only_instances_ GUARDED_BY(service_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary); }; diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..2ff6f0b300f9e2cc776e60bb27a3952356657780 --- /dev/null +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/compile_only_client.h" + +#include "external/llvm/include/llvm/ADT/Triple.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +StatusOr>> +CompileOnlyClient::CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options) { + std::vector service_instances; + service_instances.reserve(computations.size()); + for (const AotComputationInstance& instance : computations) { + service_instances.push_back({}); + CompileOnlyService::AotComputationInstance& service_instance = + service_instances.back(); + TF_RET_CHECK(instance.computation != nullptr); + service_instance.computation = instance.computation->handle(); + service_instance.argument_layouts = instance.argument_layouts; + service_instance.result_layout = instance.result_layout; + } + return compiler_service_->CompileAheadOfTime(service_instances, options); +} + +int64 CompileOnlyClient::PointerSizeForTriple( + tensorflow::StringPiece target_triple) { + llvm::Triple triple( + llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple))); + if (triple.isArch64Bit()) { + return 8; + } else if (triple.isArch32Bit()) { + return 4; + } else { + CHECK(triple.isArch16Bit()); + return 2; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h new file mode 100644 index 0000000000000000000000000000000000000000..5900048711384e0240a3cd502260eb388eb40f51 --- /dev/null +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/service/compile_only_service.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// An XLA Client specialization for doing ahead-of-time compilation. This does +// not require (or attempt to instantiate) an execution-capable backend for the +// relevant platform. +class CompileOnlyClient : public Client { + public: + explicit CompileOnlyClient(CompileOnlyService* service) + : Client(service), compiler_service_(service) {} + + CompileOnlyClient(const CompileOnlyClient&) = delete; + void operator=(const CompileOnlyClient&) = delete; + + // A description of a computation to compile using CompileAheadOfTime. + struct AotComputationInstance { + const Computation* computation; + // Inform the compiler of the expected layout for arguments. + std::vector argument_layouts; + // Specifies the expected result layout. + const Shape* result_layout; + }; + + // Compiles a list of computations for ahead-of-time execution. 
This is + // intended for use in static compilation. The |options| parameter describes + // the target for which the compiler should emit code. + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options); + + // Returns the size of a pointer in bytes for a given triple. + static int64 PointerSizeForTriple(tensorflow::StringPiece triple); + + private: + CompileOnlyService* compiler_service_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 87ceb43d1fe6650e1d160f3099b883ea208d8aac..6af69eeec12dec0ea1303826859d4655cf92932e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -668,6 +668,14 @@ class ComputationBuilder { // then Build() should be used instead. Computation BuildAndNoteError(); + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // ComputationDataHandle and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + private: using PopulateLiteral = std::function; diff --git a/tensorflow/compiler/xla/client/global_data.h b/tensorflow/compiler/xla/client/global_data.h index eb11d91034ba524f093ff80fa7cd0473e04eac2c..b7929357d06032b55c04bf0391f7fa703ee15f17 100644 --- a/tensorflow/compiler/xla/client/global_data.h +++ b/tensorflow/compiler/xla/client/global_data.h @@ -23,13 +23,15 @@ limitations under the License. namespace xla { -// Wraps a GlobalDataHandle with a lifetime. +// A GlobalData object represents a globally-accessible allocation of +// data in the associated XLA service. 
class GlobalData { public: // Gives ownership of the global data handle to this object. GlobalData(ServiceInterface* parent, GlobalDataHandle handle); - // Unregisters the wrapped handle. + // Unregisters the wrapped handle, which causes the service to + // deallocate the associated data. ~GlobalData(); const GlobalDataHandle& handle() const { return handle_; } diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index bfd14bc1c010353e3e473f10dd6c030cb0438648..02cf57e7632a2064e646d4dc441e3ec119053564 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -176,17 +176,24 @@ StatusOr> LocalExecutable::Run( TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_)); ExecutableRunOptions actual_options = options; - Backend::StreamPtr stream; if (options.stream() == nullptr) { TF_ASSIGN_OR_RETURN( - stream, BorrowStreamForDevice(options.device_ordinal(), backend_)); + Backend::StreamPtr stream, + BorrowStreamForDevice(options.device_ordinal(), backend_)); actual_options.set_stream(stream.get()); } if (options.allocator() == nullptr) { actual_options.set_allocator(backend_->memory_allocator()); } - ServiceExecutableRunOptions service_options(actual_options, - backend_->StreamBorrower()); + + // For local client execution on CPU backends: + // *) The thread pool used for eigen CPU ops is from + // ExecutableRunOptions.eigen_intra_op_thread_pool. + // *) The thread pool used for XLA CPU ops is from + // backend_->eigen_intra_op_thread_pool(). 
+ ServiceExecutableRunOptions service_options( + actual_options, backend_->StreamBorrower(), + backend_->eigen_intra_op_thread_pool()); if (executable_->dumping()) { return ExecuteAndDump(&service_options, arguments); @@ -253,46 +260,6 @@ StatusOr> LocalClient::AllocateBufferOnDevice( return std::unique_ptr(new GlobalData(local_service_, handle)); } -tensorflow::Status LocalClient::ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs) { - return local_service_->ResolveArguments(arguments, device_ordinal, - argument_ptrs); -} - -StatusOr>> -LocalClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options) { - std::vector service_instances; - service_instances.reserve(computations.size()); - for (const AheadOfTimeComputationInstance& instance : computations) { - service_instances.push_back({}); - LocalService::AheadOfTimeComputationInstance& service_instance = - service_instances.back(); - TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); - service_instance.argument_layouts = instance.argument_layouts; - service_instance.result_layout = instance.result_layout; - } - return local_service_->CompileAheadOfTime(service_instances, options); -} - -int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) { - llvm::Triple triple( - llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple))); - if (triple.isArch64Bit()) { - return 8; - } else if (triple.isArch32Bit()) { - return 4; - } else { - CHECK(triple.isArch16Bit()); - return 2; - } -} - se::Platform* LocalClient::platform() const { return local_service_->backend().platform(); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 2c467efcea119b66ad08e0636eca0f1acec3a3b8..49ffed4dde6ba9b6683d42cefec593a0c35bca6e 100644 --- 
a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -148,7 +148,7 @@ class LocalExecutable { const ExecutableBuildOptions& build_options_; }; -// An XLA service client object for use when the client and service run in +// An XLA Client specialization for use when the client and service run in // the same process. class LocalClient : public Client { public: @@ -158,14 +158,6 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // For an array of arguments held on the local service, validate - // that each is placed on the specified device_ordinal, and return - // the DeviceMemoryBase corresponding to each argument. - tensorflow::Status ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs); - // Return a handle to a buffer large enough to hold shape, allocated // on device_ordinal on the local service. If // allocate_space_for_deep_copy, the buffer is large enough to hold @@ -182,30 +174,6 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AheadOfTimeComputationInstance { - const Computation* computation; - // Inform the compiler of the expected layout for arguments. - std::vector argument_layouts; - // Specifies the expected result layout. - const Shape* result_layout; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - // - // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its - // own library. 
- StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options); - - // Returns the size of a pointer in bytes for a given triple. - static int64 PointerSizeForTriple(tensorflow::StringPiece triple); - // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 92aca3cae9e442453d3726972179d126959dca2f..76c0168f370ff1f0749759705b7ecff359a80341 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -131,4 +131,23 @@ namespace xla { return false; } +/* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, + int64 dimension) { + const Layout& layout = shape.layout(); + int64 pdim_size = layout.padded_dimensions_size(); + int64 stride = 1; + DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); + for (auto dim : layout.minor_to_major()) { + if (dim == dimension) { + break; + } + if (pdim_size == 0) { + stride *= shape.dimensions(dim); + } else { + stride *= layout.padded_dimensions(dim); + } + } + return stride; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index e6a26d622016c89f5459a50bd0f733daef469fae..c9838966a5b67397eb5fc4afe3ab9d98e82eb2b1 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -61,6 +61,14 @@ class IndexUtil { static bool BumpIndices(const Shape& shape, tensorflow::gtl::MutableArraySlice indices); + // Calculates the stride size (in number of elements, not byte size) of a + // given logical shape dimension (from 0 to rank-1). If available, padded + // dimensions are used. 
+ // Example: + // GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) == + // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 + static int64 GetDimensionStride(const Shape& shape, int64 dimension); + private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); }; diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 7091c324d14552d8b7603c3872d0ffc59771d8f7..0f622f9153436f58b05a4b5f4ea1dc0576da3e23 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -16,12 +16,15 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include +#include +#include #include #include #include #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -33,6 +36,137 @@ limitations under the License. namespace xla { +LiteralUtil::StrideConfig::StrideConfig( + const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions) + : dimensions(dimensions), + base(dimensions.size(), 0), + step(dimensions.size(), 1) { + if (!dimensions.empty()) { + // Selects the shape with the highest minor dimension as the one upon + // which to run the tight stride loop. 
+ if (source_shape.layout().minor_to_major()[0] >= + dest_shape.layout().minor_to_major()[0]) { + minor_dimension = source_shape.layout().minor_to_major()[0]; + dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); + } else { + minor_dimension = dest_shape.layout().minor_to_major()[0]; + source_stride = + IndexUtil::GetDimensionStride(source_shape, minor_dimension); + } + minor_loop_size = dimensions[minor_dimension]; + step[minor_dimension] = minor_loop_size; + } +} + +/* static */ std::unique_ptr LiteralUtil::CreateFromShape( + const Shape& shape) { + auto literal = MakeUnique(); + *literal->mutable_shape() = shape; + Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get()); + return literal; +} + +/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions) { + return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); +} + +template +/* static */ Status LiteralUtil::CopyRange( + const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + const Shape& src_shape = src_literal.shape(); + const Shape& dest_shape = dest_literal->shape(); + tensorflow::gtl::ArraySlice src_data = GetArraySlice(src_literal); + tensorflow::gtl::MutableArraySlice dest_data = + GetMutableArraySlice(dest_literal); + + TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size()); + if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) { + // If any of the two shapes are scalars, we can just call the StridedCopy() + // directly, and we know we will be copying only one value. 
+ TF_RET_CHECK(copy_size.empty()); + StridedCopy(dest_data, LinearIndex(*dest_literal, dest_base), 0, src_data, + LinearIndex(src_literal, src_base), 0, 1); + } else if (!ShapeUtil::HasZeroElements(dest_shape)) { + TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape)); + TF_RET_CHECK(src_base.size() == dest_base.size()); + TF_RET_CHECK(src_base.size() == copy_size.size()); + + // Scan the source from minor, stepping in copy size blocks, then within + // the index enumeration functor, do a strided copy advancing source index + // by one (walking through the minor dimension), and destination index by + // proper stride size at the matching dimension. + DimensionVector src_indexes(src_base.size(), 0); + DimensionVector dest_indexes(dest_base.size(), 0); + StrideConfig stride_config(src_shape, dest_shape, copy_size); + + auto copy_proc = [&](const std::vector& indexes) { + // Map from multi-dimensional index, to source index. + std::transform(indexes.begin(), indexes.end(), src_base.begin(), + src_indexes.begin(), std::plus()); + // Map from multi-dimensional index, to destination index. 
+ std::transform(indexes.begin(), indexes.end(), dest_base.begin(), + dest_indexes.begin(), std::plus()); + + int64 src_index = LinearIndex(src_literal, src_indexes); + int64 dest_index = LinearIndex(*dest_literal, dest_indexes); + + StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data, + src_index, stride_config.source_stride, + stride_config.minor_loop_size); + return true; + }; + + ShapeUtil::ForEachIndex(src_shape, stride_config.base, + stride_config.dimensions, stride_config.step, + copy_proc); + } + return Status::OK(); +} + +/* static */ Status LiteralUtil::Copy( + const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK( + ShapeUtil::SameElementType(src_literal.shape(), dest_literal->shape())); + switch (src_literal.shape().element_type()) { + case U32: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case U64: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case S32: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case S64: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case F16: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case F32: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case F64: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case PRED: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + default: + break; + } + return Unimplemented("Unhandled primitive type %d", + src_literal.shape().element_type()); +} + /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: @@ -47,6 +181,8 @@ namespace xla { return *LiteralUtil::CreateR0(0); case S64: return *LiteralUtil::CreateR0(0); + 
case F16: + return *LiteralUtil::CreateR0(static_cast(0.0f)); case F32: return *LiteralUtil::CreateR0(0); case F64: @@ -56,8 +192,6 @@ namespace xla { case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - LOG(FATAL) << "f16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; case OPAQUE: @@ -91,7 +225,7 @@ namespace xla { case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -127,7 +261,8 @@ namespace xla { case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -163,7 +298,8 @@ namespace xla { case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -197,37 +333,16 @@ namespace xla { /* static */ std::unique_ptr LiteralUtil::Relayout( const Literal& original, const Layout& layout) { - // Note: if this were a performance bottleneck, we avoid cloning and just make - // an uninitialized array instead, since all values are clobbered below. 
std::unique_ptr result = CloneToUnique(original); *result->mutable_shape()->mutable_layout() = layout; - const PrimitiveType primitive_type = original.shape().element_type(); - switch (primitive_type) { - case F32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, float value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - case S32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, int32 value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - case U32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, uint32 value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(primitive_type); - } + + const Shape& shape = original.shape(); + DimensionVector base(ShapeUtil::Rank(shape), 0); + DimensionVector copy_size(shape.dimensions().begin(), + shape.dimensions().end()); + + TF_CHECK_OK(Copy(original, base, result.get(), base, copy_size)); + return result; } /* static */ StatusOr> LiteralUtil::Reshape( @@ -235,25 +350,19 @@ namespace xla { if (ShapeUtil::IsTuple(input.shape())) { return InvalidArgument("Reshape does not support tuples."); } - + std::unique_ptr output; if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) { - return Unimplemented( - "Input shape must have a monotonic layout where dimension 0 is major, " - "was: %s", - LayoutUtil::HumanString(input.shape().layout()).c_str()); + std::vector minor_to_major(ShapeUtil::Rank(input.shape())); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), + static_cast(0)); + output = Relayout(input, LayoutUtil::MakeLayout(minor_to_major)); + } else { + output = CloneToUnique(input); } - std::vector layout(dimensions.size()); - std::iota(layout.rbegin(), layout.rend(), 0); - // Because the layout is monotonic, we can simply reuse the same sequence of // values 
without changing their order. - std::unique_ptr output = CloneToUnique(input); - output->clear_shape(); - output->mutable_shape()->set_element_type(input.shape().element_type()); - for (int64 dimension : dimensions) { - output->mutable_shape()->add_dimensions(dimension); - } - *output->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(layout); + *output->mutable_shape() = + ShapeUtil::MakeShape(input.shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(input.shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -267,73 +376,42 @@ namespace xla { return std::move(output); } -namespace { - -template -void TransposeLiteralInternal(const Literal& original, - tensorflow::gtl::ArraySlice permutation, - Literal* result) { - std::vector new_indices(ShapeUtil::Rank(original.shape())); - LiteralUtil::EachCell( - original, [&](tensorflow::gtl::ArraySlice indices, T value) { - for (int64 i = 0; i < indices.size(); ++i) { - new_indices[i] = indices[permutation[i]]; - } - LiteralUtil::Set(result, new_indices, value); - }); -} -} // namespace - /* static */ std::unique_ptr LiteralUtil::Transpose( const Literal& original, tensorflow::gtl::ArraySlice permutation) { CHECK(!ShapeUtil::IsTuple(original.shape())) - << "tuple is not supported for transpose"; - std::vector dimension_numbers(ShapeUtil::Rank(original.shape())); - std::iota(dimension_numbers.begin(), dimension_numbers.end(), 0); - CHECK(std::is_permutation(permutation.begin(), permutation.end(), - dimension_numbers.begin())) - << "given permutation is not a permutation of dimension numbers"; - std::vector new_dimension_sizes; - for (const int64 dim : permutation) { - new_dimension_sizes.push_back(original.shape().dimensions(dim)); - } - const auto result_shape = ShapeUtil::MakeShape( - original.shape().element_type(), new_dimension_sizes); - std::unique_ptr result = CloneToUnique(original); - *result->mutable_shape() = result_shape; - const PrimitiveType 
primitive_type = original.shape().element_type(); - switch (primitive_type) { - case F32: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case F64: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case PRED: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case S8: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case U8: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case S32: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case U32: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case S64: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case U64: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(primitive_type); + << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(original.shape()))) + << "Given permutation is not a permutation of dimension numbers"; + // To transpose the array, we just permute the dimensions and layout, and + // do a straight memory copy of the raw data set. + // This is considerably faster than iterating over every array element using + // the EachCell<>() and Set<>() APIs. + std::vector inverse_permutation = InversePermutation(permutation); + Shape shape = + ShapeUtil::PermuteDimensions(inverse_permutation, original.shape()); + // Replace the layout with one affine to the original shape, such that a + // transpose operation can be performed by leaving the flat values + // representation intact. + // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. 
+ // The shape with affine layout resulting from that operation will be + // F32[8,11]{0,1}, since it leave the original most minor (the 8 sized), the + // most minor. + // Essentially, given MinMaj(Di) the position of the Di dimension within the + // minor to major vector, and given T(Di) the index that the original Di + // dimension has within the transposed array, a layout is affine if + // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major + // vector of the affine layout. + Layout* layout = shape.mutable_layout(); + layout->clear_minor_to_major(); + for (auto index : original.shape().layout().minor_to_major()) { + layout->add_minor_to_major(inverse_permutation[index]); } + std::unique_ptr new_literal = CreateFromShape(shape); + DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + ShapeUtil::ByteSizeOf(original.shape())); + std::memcpy(MutableInternalData(new_literal.get()), InternalData(original), + ShapeUtil::ByteSizeOf(original.shape())); + return new_literal; } /* static */ std::unique_ptr LiteralUtil::Slice( @@ -342,7 +420,7 @@ void TransposeLiteralInternal(const Literal& original, CHECK(!ShapeUtil::IsTuple(literal.shape())) << "tuple is not supported for reshape"; - std::vector result_dimensions; + DimensionVector result_dimensions; for (int64 dnum = 0; dnum < ShapeUtil::Rank(literal.shape()); ++dnum) { CHECK_GE(start_indices[dnum], 0); CHECK_LE(limit_indices[dnum], literal.shape().dimensions(dnum)); @@ -358,7 +436,7 @@ void TransposeLiteralInternal(const Literal& original, *result_literal->mutable_shape() = result_shape; Reserve(ShapeUtil::ElementsIn(result_shape), result_literal.get()); - std::vector new_indices(ShapeUtil::Rank(result_shape)); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: LiteralUtil::EachCell( @@ -425,6 +503,8 @@ void TransposeLiteralInternal(const Literal& original, return tensorflow::strings::StrCat(Get(literal, multi_index)); case F64: return 
tensorflow::strings::StrCat(Get(literal, multi_index)); + case F16: + return tensorflow::strings::StrCat(Get(literal, multi_index)); default: return tensorflow::strings::StrCat( "[", PrimitiveType_Name(literal.shape().element_type()), "]"); @@ -579,6 +659,8 @@ void TransposeLiteralInternal(const Literal& original, return reinterpret_cast(literal.f32s().data()); case F64: return reinterpret_cast(literal.f64s().data()); + case F16: + return reinterpret_cast(literal.f16s().data()); default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(literal.shape().element_type()); @@ -593,38 +675,33 @@ void TransposeLiteralInternal(const Literal& original, CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); switch (literal->shape().element_type()) { case PRED: - GetMutableRepeatedField(literal)->Resize(num_elements, false); + Resize(num_elements, false, literal); + break; + case S8: + Resize(num_elements, 0, literal); break; case U8: - // u8s is an optional "bytes", rather than a repeated field. Therefore its - // access methods are somewhat different from the others. 
- literal->mutable_u8s()->resize(num_elements, 0); + Resize(num_elements, 0, literal); break; case S32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0); + Resize(num_elements, 0, literal); break; case S64: - GetMutableRepeatedField(literal)->Resize( - num_elements, - /*value=*/0); + Resize(num_elements, 0, literal); break; case U32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0); + Resize(num_elements, 0, literal); break; case U64: - GetMutableRepeatedField(literal)->Resize( - num_elements, - /*value=*/0); + Resize(num_elements, 0, literal); break; case F32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0.0f); + Resize(num_elements, 0, literal); break; case F64: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0.0); + Resize(num_elements, 0, literal); + case F16: + Resize(num_elements, static_cast(0.0f), literal); break; default: LOG(FATAL) << "primitive type not supported in literals: " @@ -662,6 +739,9 @@ void TransposeLiteralInternal(const Literal& original, case F64: actual = literal.f64s_size(); break; + case F16: + actual = literal.f16s().size() / sizeof(half); + break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + @@ -680,50 +760,16 @@ void TransposeLiteralInternal(const Literal& original, /* static */ void LiteralUtil::EachCellAsString( const Literal& literal, - std::function indices, - const string& value)> - per_cell) { - if (ShapeUtil::Rank(literal.shape()) == 1) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - per_cell({i0}, GetAsString(literal, {i0})); - } - return; - } - - if (ShapeUtil::Rank(literal.shape()) == 2) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - per_cell({i0, i1}, GetAsString(literal, {i0, i1})); - } - } + const std::function indices, + const string& value)>& per_cell) { + if 
(ShapeUtil::HasZeroElements(literal.shape())) { return; } - - if (ShapeUtil::Rank(literal.shape()) == 3) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { - per_cell({i0, i1, i2}, GetAsString(literal, {i0, i1, i2})); - } - } - } - return; - } - - if (ShapeUtil::Rank(literal.shape()) == 4) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { - for (int64 i3 = 0; i3 < literal.shape().dimensions(3); ++i3) { - per_cell({i0, i1, i2, i3}, GetAsString(literal, {i0, i1, i2, i3})); - } - } - } - } - return; - } - - LOG(FATAL) << "unhandled rank: " << ShapeUtil::Rank(literal.shape()); + std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + literal.shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(literal, indices)); + } while (IndexUtil::BumpIndices(literal.shape(), &indices)); } namespace { @@ -786,6 +832,8 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, return EqualElements(literal1, literal2, 0, &multi_index); case F64: return EqualElements(literal1, literal2, 0, &multi_index); + case F16: + return EqualElements(literal1, literal2, 0, &multi_index); default: LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type " << PrimitiveType_Name(literal1.shape().element_type()); @@ -794,96 +842,176 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK(literal.shape().element_type() == PRED); - return literal.preds(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_preds(); + return 
tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == PRED); - return literal->mutable_preds(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + auto values = literal->mutable_u8s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == U32); - return literal.u32s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. 
+ auto values = literal->mutable_u8s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == U32); - return literal->mutable_u32s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_s32s(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == U64); - return AsUInt64Slice(literal.u64s()); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_u32s(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal) { - CHECK(literal->shape().element_type() == U64); - return literal->mutable_u64s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) && + alignof(int64) == alignof(tensorflow::protobuf_int64), + "The int64 and tensorflow::protobuf_int64 types are not " + "compatible"); + auto values = literal->mutable_s64s(); + // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t + // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is + // necessary from the raw data pointer returned by the mutable_data() API. 
+ return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(values->mutable_data()), values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == S32); - return literal.s32s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) && + alignof(uint64) == alignof(tensorflow::protobuf_uint64), + "The uint64 and tensorflow::protobuf_uint64 types are not " + "compatible"); + auto values = literal->mutable_u64s(); + // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t + // while tensorflow::uint64 is defined as unsigned long long, a + // reinterpret_cast<> is necessary from the raw data pointer returned by the + // mutable_data() API. + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(values->mutable_data()), values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == S32); - return literal->mutable_s32s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_f32s(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == S64); - return AsInt64Slice(literal.s64s()); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_f64s(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); +} + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + // 
C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + // TODO - there is an endianess problem here. fix it, or wait for uint16 + // support in protobuf + auto values = literal->mutable_f16s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), + values->size() / sizeof(half)); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), PRED); + return literal.preds(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), U8); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(literal.u8s().data()), + literal.u8s().size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal) { - CHECK(literal->shape().element_type() == S64); - return literal->mutable_s64s(); +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), S8); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(literal.u8s().data()), + literal.u8s().size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == F32); - return literal->mutable_f32s(); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), U32); + return literal.u32s(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), U64); + return AsUInt64Slice(literal.u64s()); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) 
{ + CHECK_EQ(literal.shape().element_type(), S32); + return literal.s32s(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), S64); + return AsInt64Slice(literal.s64s()); } template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == F64); + CHECK_EQ(literal.shape().element_type(), F64); return literal.f64s(); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == F64); - return literal->mutable_f64s(); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), F16); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(literal.f16s().data()), + literal.f16s().size() / sizeof(half)); } template @@ -925,6 +1053,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return AllElementsEqualValue(literal, value); case F64: return AllElementsEqualValue(literal, value); + case F16: + return AllElementsEqualValue(literal, static_cast(value)); case PRED: if (value == 0) { return AllElementsEqualValue(literal, false); @@ -944,6 +1074,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return AllElementsEqualValue(literal, value); case F64: return AllElementsEqualValue(literal, value); + case F16: + return AllElementsEqualValue(literal, static_cast(value)); default: return false; } @@ -968,6 +1100,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return Get(literal, indices) == 0.0f; case F64: return Get(literal, indices) == 0.0; + case F16: + return Get(literal, indices) == static_cast(0.0f); case PRED: return Get(literal, indices) == false; default: @@ -976,51 +1110,77 @@ static bool AllElementsEqualValue(const 
Literal& literal, NativeT value) { } template <> -/* static */ void LiteralUtil::PopulateWithValue( - int64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); - } +/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_preds()->Resize(num_elements, value); } template <> -/* static */ void LiteralUtil::PopulateWithValue( - uint64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); - } +/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_u8s()->resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_u8s()->resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal) { +/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, + Literal* literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - 
repeated_field->Resize(num_elements, value); + literal->mutable_s32s()->Resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal) { +/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, + Literal* literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Resize(num_elements, value); + literal->mutable_u32s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_s64s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_u64s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, float value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_f32s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, double value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_f64s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, half value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_f16s()->resize(num_elements * sizeof(half)); + auto data = GetMutableArraySlice(literal); + for (int i = 0; i < num_elements; i++) { + data[i] = value; + } } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 
21bb2e46cf2ebcd72bcce393a1e5526f41757544..2da010d56e38c18aed1362a4d2cff1708740ffe9 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -100,6 +101,31 @@ class LiteralUtil { values, const Layout& layout); + // Create a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromShape(const Shape& shape); + + // Create a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to dest_literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and dest_literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. 
+ static Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + // Creates a new value that has the equivalent value as literal, but conforms // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major // dimension layout can be re-layed-out as {1, 0} minor-to-major dimension @@ -213,6 +239,11 @@ class LiteralUtil { // Clones literal into an owned unique_ptr version. static std::unique_ptr CloneToUnique(const Literal& literal); + // Returns the linear index of the given index within the literal's + // element_type repeated field. + static int64 LinearIndex(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index); + // Gets or sets an element in the literal at the given index. The index is // CHECKed against the dimension sizes. template @@ -223,6 +254,12 @@ class LiteralUtil { tensorflow::gtl::ArraySlice multi_index, NativeT value); + // Retrieves the mutable array slice interface which can be used to manipulate + // pre-allocated literal values. + template + static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( + Literal* literal); + // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. template @@ -257,9 +294,8 @@ class LiteralUtil { // like representation in a protobuf). static void EachCellAsString( const Literal& literal, - std::function indices, - const string& value)> - per_cell); + const std::function indices, + const string& value)>& per_cell); template static void EachCell( const Literal& literal, @@ -315,6 +351,14 @@ class LiteralUtil { const Layout& layout, Literal* literal); + // Populates literal values by calling the generator function for every cell + // in the literal object. 
+ template + static Status Populate( + Literal* literal, + const std::function indexes)>& + generator); + // Creates a Literal of the given dimensions with all elements set to the // given value. template @@ -383,70 +427,73 @@ class LiteralUtil { static_assert(!std::is_same::value, "Cannot map native type to primitive type."); } - template - static tensorflow::protobuf::RepeatedField* GetMutableRepeatedField( - Literal* literal) { - // Make the expression depend on the template parameter NativeT so - // that this compile-time error only apperas if this function is - // instantiated with some concrete type that is not specialized - // below. - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } - // Returns the linear index of the given index within the literal's - // element_type repeated field. - static int64 LinearIndex(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + // Internal template helper for the Copy() API, matching its arguments one by + // one. + template + static Status CopyRange(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Utility structure which is used to create the optimal configuration for + // a ShapeUtil::ForEachIndex() scan across two literals. + struct StrideConfig { + StrideConfig(const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions); + + // The dimensions of the stride operation. Essentially every dimension + // will be iterated from base[i] to base[i]+dimensions[i], in step[i] + // steps. + tensorflow::gtl::ArraySlice dimensions; + DimensionVector base; + DimensionVector step; + int64 minor_dimension = 0; + // The size of the strides for source and destination. 
One of the two + // (the one looping through its most minor dimension) will be 1, while + // the other will be the stride size at the dimension matching the other + // shape most minor dimension being scanned. + int64 dest_stride = 1; + int64 source_stride = 1; + // The size of the inner loop on the most minor dimension. + int64 minor_loop_size = 1; + }; TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); }; // Declarations of template specializations for GetArraySlice and -// GetMutableRepeatedField. The specializations map native type to XLA primitive +// GetMutableArraySlice. The specializations map native type to XLA primitive // type. template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( const Literal& literal); template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal); template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); -template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal); - template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); -template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); - template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); -template <> -/* static */ 
tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal); - template <> /* static */ inline tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal) { @@ -454,22 +501,98 @@ LiteralUtil::GetArraySlice(const Literal& literal) { return literal.f32s(); } -template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); - template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 
num_elements, bool value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, float value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, double value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, half value, + Literal* literal); template /* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { auto literal = MakeUnique(); - PopulateR0(value, literal.get()); + PopulateR0(value, literal.get()); return literal; } @@ -695,12 +818,20 @@ template <> return literal.u8s()[linear_index]; } +template <> +/* static */ inline half LiteralUtil::Get( + const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { + CHECK(literal.shape().element_type() == F16); + int64 linear_index = LinearIndex(literal, multi_index); + return GetArraySlice(literal)[linear_index]; +} + template /* static */ void LiteralUtil::Set( Literal* literal, tensorflow::gtl::ArraySlice multi_index, NativeT value) { int64 linear_index = LinearIndex(*literal, multi_index); - GetMutableRepeatedField(literal)->Set(linear_index, value); + GetMutableArraySlice(literal).at(linear_index) = value; } template <> @@ -760,44 +891,12 @@ template } template -/* static */ void LiteralUtil::PopulateR0(NativeT value, Literal* 
literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {}); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint8 value, - Literal* literal) { +/* static */ inline void LiteralUtil::PopulateR0(NativeT value, + Literal* literal) { *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u64s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_s64s()->Add(value); + ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {}); + Resize(1, value, literal); } template @@ -944,65 +1043,72 @@ template literal); } +template +/* static */ Status LiteralUtil::Populate( + Literal* literal, + const std::function indexes)>& + generator) { + const Shape& shape = literal->shape(); + int64 rank = ShapeUtil::Rank(shape); + TF_RET_CHECK(shape.element_type() == + primitive_util::NativeToPrimitiveType()); + tensorflow::gtl::MutableArraySlice data = + GetMutableArraySlice(literal); + if (rank > 0) { + StrideConfig stride_config(shape, shape, AsInt64Slice(shape.dimensions())); + DimensionVector minor_scan_indexes(rank, 0); + int64 minor_dimension_size = + 
ShapeUtil::GetDimension(shape, stride_config.minor_dimension); + + auto init_function = [&](const std::vector& indexes) { + int64 index = LinearIndex(*literal, indexes); + std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); + for (int64 i = 0; i < minor_dimension_size; ++i) { + minor_scan_indexes[stride_config.minor_dimension] = i; + data.at(index + i) = generator(minor_scan_indexes); + } + return true; + }; + ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions, + stride_config.step, init_function); + } else { + data.at(0) = generator({}); + } + return Status::OK(); +} + template /* static */ void LiteralUtil::PopulateWithValue( NativeT value, tensorflow::gtl::ArraySlice dimensions, Literal* literal) { *literal->mutable_shape() = ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); - } + Resize(ShapeUtil::ElementsIn(literal->shape()), value, literal); } -template <> -/* static */ void LiteralUtil::PopulateWithValue( - int64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal); - -template <> -/* static */ void LiteralUtil::PopulateWithValue( - uint64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal); - template /* static */ std::unique_ptr LiteralUtil::Convert( const Literal& literal) { + const Shape& shape = literal.shape(); auto result_literal = MakeUnique(); - Shape result_shape = literal.shape(); - result_shape.set_element_type( + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = shape; + result_shape->set_element_type( primitive_util::NativeToPrimitiveType()); - *result_literal->mutable_shape() = result_shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(result_shape), + LiteralUtil::Reserve(ShapeUtil::ElementsIn(*result_shape), result_literal.get()); 
- LiteralUtil::EachCell( - literal, - [&](tensorflow::gtl::ArraySlice indices, NativeSrcT value) { - LiteralUtil::Set(result_literal.get(), indices, - static_cast(value)); - }); + tensorflow::gtl::ArraySlice src_data = + GetArraySlice(literal); + tensorflow::gtl::MutableArraySlice dest_data = + GetMutableArraySlice(result_literal.get()); + int64 num_elements = ShapeUtil::ElementsIn(shape); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = static_cast(src_data[i]); + } return result_literal; } -template -/* static */ void LiteralUtil::Resize(int64 num_elements, NativeT value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Resize(num_elements, value); -} - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal); - template /* static */ std::unique_ptr LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( @@ -1022,10 +1128,7 @@ LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( template /* static */ std::unique_ptr LiteralUtil::Replicate( const Literal& input, int64 times) { - // Ranks greater than 8 are very rare, so use InlinedVector to store - // the bounds and indices. 
- static constexpr int kInlineRank = 8; - tensorflow::gtl::InlinedVector bounds = {times}; + DimensionVector bounds = {times}; bounds.reserve(input.shape().dimensions_size() + 1); for (int64 bound : input.shape().dimensions()) { bounds.push_back(bound); @@ -1039,8 +1142,7 @@ template } Reserve(elements, literal.get()); - tensorflow::gtl::InlinedVector output_indices( - bounds.size(), 0); + DimensionVector output_indices(bounds.size(), 0); tensorflow::gtl::ArraySlice input_indices = output_indices; input_indices.remove_prefix(1); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 91971c3e24c326148322202ffb684285d980d4c7..9a09822174d9c93c8195af193f34017268bbc503 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -103,6 +105,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f32_lit = LiteralUtil::CreateR0(3.14f); ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); + + auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", LiteralUtil::ToString(*f16_lit)); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -371,6 +376,15 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE( LiteralUtil::IsAll(*LiteralUtil::CreateR2({{9, 8}, {8, 8}}), 8)); + half h8(8.0f); + half h9(9.0f); + EXPECT_TRUE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h8}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h9}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h9}, {h8}}), 8)); + auto uint64_max = std::numeric_limits::max(); 
EXPECT_FALSE(LiteralUtil::IsAll( *LiteralUtil::CreateR2( @@ -467,6 +481,26 @@ TEST_F(LiteralUtilTest, ReshapeR4) { EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); } +TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { + // clang-format off + // F32[1x3x2x4] + auto original = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0minor_); + // F32[1x3x4x2] + auto expected = LiteralUtil::CreateR3WithLayout({ + {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, + {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, + {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, + }, layout_r3_dim0major_); + // clang-format on + auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + + EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); +} + TEST_F(LiteralUtilTest, TransposeR0) { auto original = LiteralUtil::CreateR0(1.7f); auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{}); @@ -637,6 +671,30 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); } +TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { + Literal output; + half h(0.25f); + LiteralUtil::PopulateWithValue(h, {}, &output); + auto expected = LiteralUtil::CreateR0(h); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { + Literal output; + half h(0.5f); + LiteralUtil::PopulateWithValue(h, {3}, &output); + auto expected = LiteralUtil::CreateR1({h, h, h}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { + Literal output; + half h(2.0f); + LiteralUtil::PopulateWithValue(h, {2, 2}, &output); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 
12}}); @@ -648,5 +706,156 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) { EXPECT_TRUE(LiteralUtil::Equal(*output, *expected)); } +TEST_F(LiteralUtilTest, Copy) { + const int64 dimensions[] = {17, 15, 34, 21}; + const int64 layouts[][4] = { + {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}}; + for (const auto& layout : layouts) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), dimensions, layout); + auto blank = LiteralUtil::CreateFromShape(shape); + auto source = LiteralUtil::CreateFromShape(shape); + const int64 zero_base[] = {0, 0, 0, 0}; + const int64 step[] = {1, 1, 1, 1}; + uint32 seqnr = 0; + auto init_proc = [&](const std::vector& indexes) { + LiteralUtil::Set(source.get(), indexes, ++seqnr); + return true; + }; + + ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, + init_proc); + + const int64 src_base[] = {3, 1, 5, 7}; + const int64 dest_base[] = {6, 4, 12, 2}; + const int64 copy_size[] = {7, 8, 11, 9}; + + TF_EXPECT_OK(LiteralUtil::Copy(*source, src_base, blank.get(), dest_base, + copy_size)); + std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); + std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); + bool matched = true; + auto check_proc = [&](const std::vector& indexes) { + std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); + std::transform(source_indexes.begin(), source_indexes.end(), src_base, + source_indexes.begin(), std::plus()); + std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); + std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, + blank_indexes.begin(), std::plus()); + auto bval = LiteralUtil::Get(*blank, blank_indexes); + matched = (bval != 0 && + bval == LiteralUtil::Get(*source, source_indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, + check_proc); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, CopyScalars) { + auto zero = 
LiteralUtil::CreateR0(0); + auto nine = LiteralUtil::CreateR0(9); + TF_EXPECT_OK(LiteralUtil::Copy(*nine, {}, zero.get(), {}, {})); + EXPECT_TRUE(LiteralUtil::Equal(*zero, *nine)); + + auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); + TF_EXPECT_OK(LiteralUtil::Copy(*vect, {5}, zero.get(), {}, {})); + EXPECT_EQ(LiteralUtil::Get(*zero, {}), 17); + TF_EXPECT_OK(LiteralUtil::Copy(*zero, {}, vect.get(), {4}, {})); + EXPECT_EQ(LiteralUtil::Get(*vect, {4}), 17); +} + +TEST_F(LiteralUtilTest, F16) { + // Verify that the internal data views are consistent and that they + // are in little endian format + // TODO - modify if we make the data format machine endianess dependent + auto m1 = LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + Literal* l1 = m1.get(); + const char* d1 = (const char*)LiteralUtil::InternalData(*l1); + EXPECT_EQ(d1[0], 0); + EXPECT_EQ(d1[1], 0); + EXPECT_EQ(d1[2], 0); + EXPECT_EQ(d1[3], 0); + EXPECT_EQ(d1[4], 0); + EXPECT_EQ(d1[5], 0); + EXPECT_EQ(d1[6], 0); + EXPECT_EQ(d1[7], 0); + EXPECT_EQ(LiteralUtil::InternalData(*l1), + LiteralUtil::MutableInternalData(l1)); + + half h1(1.0f); + half h2(2.0f); + auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); + Literal* l2 = m2.get(); + const char* d2 = (const char*)LiteralUtil::InternalData(*l2); + EXPECT_EQ(d2[0], 0); + EXPECT_EQ(d2[1], 0x3C); + EXPECT_EQ(d2[2], 0); + EXPECT_EQ(d2[3], 0x40); + EXPECT_EQ(d2[4], 0); + EXPECT_EQ(d2[5], 0x40); + EXPECT_EQ(d2[6], 0); + EXPECT_EQ(d2[7], 0x3C); + EXPECT_EQ(LiteralUtil::InternalData(*l2), + LiteralUtil::MutableInternalData(l2)); +} + +TEST_F(LiteralUtilTest, Populate) { + struct PopulateData { + std::vector dimensions; + std::vector layout; + } populate_data[] = { + {{}, {}}, + {{16}, {0}}, + {{4, 16}, {1, 0}}, + {{21, 12}, {0, 1}}, + {{6, 11, 17}, {2, 0, 1}}, + {{6, 11, 5, 17}, {3, 2, 0, 1}}, + }; + for (const auto& data : populate_data) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), 
data.dimensions, + data.layout); + auto literal = LiteralUtil::CreateFromShape(shape); + auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { + // Offsets from linear index just to avoid R0 literals to be initialized + // with zero. + return LiteralUtil::LinearIndex(*literal, indexes) + 17; + }; + TF_EXPECT_OK(LiteralUtil::Populate(literal.get(), generator)); + + std::vector zero_base(data.dimensions.size(), 0); + std::vector step(data.dimensions.size(), 1); + bool matched = true; + auto check_function = [&](const std::vector& indexes) { + auto value = LiteralUtil::Get(*literal, indexes); + matched = matched && (value == generator(indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + check_function); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, ConvertR4) { + // clang-format off + auto original = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0major_); + auto expected = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0major_); + // clang-format on + auto converted = LiteralUtil::Convert(*original); + + EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index cd7c42f6e17e15b5e1c6ebfa1f24a40a9003a63e..0d4ddc239243b79d47b6a1672b65abe9b23e7b52 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -38,7 +38,8 @@ void MetricTableReport::SetEntryName(string entry_name) { void MetricTableReport::SetShowAllEntries() { max_entries_to_show_ = std::numeric_limits::max(); - max_metric_proportion_to_show = 1.1; // more than 100% + 
max_entries_per_category_to_show_ = std::numeric_limits::max(); + max_metric_proportion_to_show_ = 1.1; // more than 100% } void MetricTableReport::SetShowCategoryTable() { show_category_table_ = true; } @@ -141,7 +142,7 @@ void MetricTableReport::AppendCategoryTable() { int64 categories_shown = 0; for (const auto& category : categories) { if (categories_shown >= max_entries_to_show_ || - metric_sum / expected_metric_sum_ > max_metric_proportion_to_show) { + metric_sum / expected_metric_sum_ > max_metric_proportion_to_show_) { break; } ++categories_shown; @@ -156,15 +157,14 @@ void MetricTableReport::AppendCategoryTable() { entry_name_, ")"); AppendTableRow(text, category.metric_sum, metric_sum); - // Show the top few entries in the category. - const int64 kMaxToShow = 5; + // Show the top entries in the category. const char* const kIndentPrefix = " * "; - int64 entries_to_show = - std::min(kMaxToShow, category.entries.size()); - if (category.entries.size() == kMaxToShow + 1) { + int64 entries_to_show = std::min(max_entries_per_category_to_show_, + category.entries.size()); + if (category.entries.size() == entries_to_show + 1) { // May as well show the last entry on the line that would otherwise say // that there is a single entry not shown. 
- entries_to_show = category.entries.size(); + ++entries_to_show; } for (int64 i = 0; i < entries_to_show; ++i) { AppendLine(kIndentPrefix, MetricPercent(category.entries[i]->metric), " ", @@ -193,7 +193,7 @@ void MetricTableReport::AppendEntryTable() { int64 entries_shown = 0; for (const auto& entry : entries_) { if (entries_shown >= max_entries_to_show_ || - metric_sum / expected_metric_sum_ > max_metric_proportion_to_show) { + metric_sum / expected_metric_sum_ > max_metric_proportion_to_show_) { break; } ++entries_shown; diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h index e967627bff4446a695bfae514faac4b1acca4968..818fb1d3fe0b8bbe1a8eba363ff6445e2f3df9d2 100644 --- a/tensorflow/compiler/xla/metric_table_report.h +++ b/tensorflow/compiler/xla/metric_table_report.h @@ -103,6 +103,7 @@ class MetricTableReport { private: static constexpr double kDefaultMaxMetricProportionToShow = 0.99; static constexpr int64 kDefaultMaxEntriesToShow = 100; + static constexpr int64 kDefaultMaxEntriesPerCategoryToShow = 5; // Append all parameters to the report. template @@ -162,7 +163,8 @@ class MetricTableReport { // These members control how many categories and entries to show in tables. int64 max_entries_to_show_ = kDefaultMaxEntriesToShow; - double max_metric_proportion_to_show = kDefaultMaxMetricProportionToShow; + int64 max_entries_per_category_to_show_ = kDefaultMaxEntriesPerCategoryToShow; + double max_metric_proportion_to_show_ = kDefaultMaxMetricProportionToShow; // The report that is being created. string report_; diff --git a/tensorflow/compiler/xla/port/BUILD b/tensorflow/compiler/xla/port/BUILD deleted file mode 100644 index 6fc5f1185c9d56075f18928e4b2c8e3819cf9ddd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/port/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -# Filegroup used to collect source files for dependency checking. 
-filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), - visibility = ["//tensorflow/compiler/xla:internal"], -) - -cc_library( - name = "initialize", - hdrs = ["initialize.h"], - visibility = [ - "//tensorflow/compiler/xla:__subpackages__", - ], -) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index e3909ae8e9736351d3ee91332572b5db62727289..e4e37177a2d74e6da20300f1439942a146ad8d49 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType() { + return F16; +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64; } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 78f0ee6f592d9b9ec2ed85f23297634c5e2e4d41..162a11c7d2966346979b98c804917203f82c806c 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -75,6 +75,8 @@ template <> PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); bool IsFloatingPointType(PrimitiveType type); @@ -150,6 +152,10 @@ template <> struct PrimitiveTypeToNative { using type = double; }; +template <> +struct PrimitiveTypeToNative { + using type = half; +}; } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 5630033ac89b3aefbb8503f8e04fe268f9ab4da6..4194d5fc6be0ad552e9fe6dd14b51fa0a67f2eca 100644 --- 
a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -180,14 +180,28 @@ ReferenceUtil::ReduceWindow4DGeneric( const tensorflow::gtl::ArraySlice& stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow4DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} + +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow4DGeneric( + const Array4D& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { + std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0], window_counts[1], window_counts[2], window_counts[3]); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index eb1eea7fc4c68a3a29cdf8b7eef9773b990b1bbc..f58f0bdc9f51dff62c10dda4aba7aac03e689ce7 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -162,6 +162,12 @@ class ReferenceUtil { const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow4DGeneric( + const Array4D& operand, float init, + const std::function& reduce_func, + 
const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. @@ -400,7 +406,46 @@ class ReferenceUtil { const PaddingConfig& padding, const float pad); + // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running + // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, .... + // + // The given arrays must have the same size and element type, and the return + // type of f must be implicitly convertible to the arrays' element type. + // + // Example usage: + // + // Array2D x, y, z = ...; + // std::unique_ptr result = ReferenceUtil::ApplyElementwise2D( + // [](float a, float b, float c) { return a * b + c; }, x, y, z); + // + template + static std::unique_ptr> ApplyElementwise2D( + F&& f, const Array2D& array1, const Array2D&... arrays) { + AssertSameSize2D(array1, arrays...); + auto result = MakeUnique>(array1.n1(), array1.n2()); + for (int64 i = 0; i < array1.n1(); ++i) { + for (int64 j = 0; j < array1.n2(); ++j) { + (*result)(i, j) = f(array1(i, j), arrays(i, j)...); + } + } + return result; + } + private: + template + static void AssertSameSize2D(const Array2D& array1, + const Array2D& array2, + const Array2D&... arrays) { + static_assert(std::is_same::value, "Args must be same type."); + CHECK_EQ(array1.n1(), array2.n1()); + CHECK_EQ(array1.n2(), array2.n2()); + AssertSameSize2D(array2, arrays...); + } + + // Recursive base case for AssertSameSize2D. 
+ template + static void AssertSameSize2D(const Array1& array1) {} + TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil); }; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index b0aa55840283c011099ef0f4263307a4ef101382..f839ac019df07c5c5e07eed856ea55463bb3efae 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -52,9 +52,9 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -62,32 +62,32 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto result_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *result_literal, + auto actual_literal = LiteralUtil::CreateR1(*result); + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto result_literal 
= LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *result_literal, + auto actual_literal = LiteralUtil::CreateR1(*result); + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *result_literal, + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -96,9 +96,9 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { @@ -107,11 +107,11 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); } @@ -124,11 +124,11 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = 
ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); } @@ -302,5 +302,17 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, ApplyElementwise2D) { + Array2D a({{1, 2}, {3, 4}}); + Array2D b({{10, 20}, {30, 40}}); + Array2D c({{100, 200}, {300, 400}}); + + auto actual = ReferenceUtil::ApplyElementwise2D( + [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); + LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, + *actual_literal, ErrorSpec(0.0001)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e5a921674f675168b0c30198ce25146d7bc91302..75a0f6f0f3be116343343b6ef45afc3913e35c61 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -195,7 +195,6 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], @@ -407,6 +406,27 @@ cc_library( ], ) +cc_library( + name = "compile_only_service", + srcs = ["compile_only_service.cc"], + hdrs = ["compile_only_service.h"], + deps = [ + ":backend", + ":compiler", + ":computation_layout", + ":computation_tracker", + ":platform_util", + ":service", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + 
"//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + cc_library( name = "cpu_plugin", deps = [ @@ -624,6 +644,7 @@ cc_library( "buffer_liveness.h", ], deps = [ + ":call_graph", ":hlo", ":hlo_ordering", ":liveness_util", @@ -664,8 +685,8 @@ cc_library( ], deps = [ ":buffer_liveness", - ":heap_simulator", ":hlo", + ":hlo_ordering", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -705,50 +726,38 @@ cc_test( ], ) -cc_library( - name = "heap_simulator", - srcs = [ - "heap_simulator.cc", - ], - hdrs = [ - "heap_simulator.h", - ], - deps = [ - ":hlo", - ":liveness_util", - ":logical_buffer", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - ], -) - cc_test( name = "heap_simulator_test", srcs = ["heap_simulator_test.cc"], deps = [ - ":heap_simulator", ":hlo", + ":hlo_ordering", ":logical_buffer", ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) +# The hlo_ordering library contains both hlo_ordering and heap_simulator because +# they are mutually dependent. 
cc_library( name = "hlo_ordering", srcs = [ + "heap_simulator.cc", "hlo_ordering.cc", ], hdrs = [ + "heap_simulator.h", "hlo_ordering.h", ], deps = [ - ":heap_simulator", + ":call_graph", ":hlo", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -858,7 +867,9 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", ], ) @@ -1190,6 +1201,7 @@ cc_library( ":buffer_liveness", ":hlo", ":hlo_pass", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:status_macros", @@ -1254,6 +1266,7 @@ cc_library( ":hlo_cost_analysis", ":hlo_dce", ":hlo_ordering", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1271,6 +1284,7 @@ cc_test( deps = [ ":cpu_plugin", ":hlo", + ":hlo_matchers", ":hlo_ordering", ":hlo_rematerialization", "//tensorflow/compiler/xla:shape_util", @@ -1384,6 +1398,7 @@ cc_test( ":cpu_plugin", ":hlo", ":hlo_cse", + ":hlo_matchers", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1412,6 +1427,28 @@ cc_library( ], ) +cc_test( + name = "hlo_constant_folding_test", + srcs = ["hlo_constant_folding_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo", + ":hlo_constant_folding", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "device_memory_allocator", srcs = 
["device_memory_allocator.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 6acb9bdcbac2e79538d14d94003e15d11058f1a9..3f888b4c2e378bd88fcafa02171fef52ccd758f9 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1,3 +1,4 @@ + /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -51,6 +52,16 @@ bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { LiteralUtil::IsAll(operand->literal(), value); } +bool IsAll(const HloInstruction* op, int8 value) { + if (IsLiteralWithValue(op, value)) { + return true; + } + if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) { + return true; + } + return false; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. 
bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -150,9 +161,17 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice dimensions, HloComputation* function) override; + Status HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, const Window& window, + HloComputation* function) override; + Status HandleReverse(HloInstruction* reverse, HloInstruction* operand) override; Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices) override; Status HandleTranspose(HloInstruction* transpose) override; @@ -214,6 +233,29 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* reshape_or_broadcast); + // Replaces the existing HLO instruction old_instruction, with + // new_instruction, and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceWithNewInstruction( + HloInstruction* old_instruction, + std::unique_ptr new_instruction) { + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + old_instruction, std::move(new_instruction))); + changed_ = true; + return Status::OK(); + } + + // Replaces the existing HLO instruction old_instruction, with + // new_instruction, and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceInstruction(HloInstruction* old_instruction, + HloInstruction* new_instruction) { + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(old_instruction, new_instruction)); + changed_ = true; + return Status::OK(); + } + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. 
HloComputation* computation_; @@ -262,8 +304,7 @@ void AlgebraicSimplifierVisitor::ReplaceWithBitcast( auto bitcast = computation_->AddInstruction( HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, instruction->mutable_operand(0))); - TF_CHECK_OK(computation_->ReplaceInstruction(instruction, bitcast)); - changed_ = true; + TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); } bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( @@ -271,9 +312,7 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( if (!SameShape(old_instruction, new_instruction)) { return false; } - TF_CHECK_OK( - computation_->ReplaceInstruction(old_instruction, new_instruction)); - changed_ = true; + TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction)); return true; } @@ -282,12 +321,12 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, HloInstruction* rhs) { // A + 0 => A VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); - if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { return Status::OK(); } // 0 + A => A VLOG(10) << "trying transform [0 + A => A]: " << add->ToString(); - if (IsLiteralWithValue(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { return Status::OK(); } @@ -304,9 +343,32 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, Status AlgebraicSimplifierVisitor::HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) { - // Unary concatenates are useless. if (operands.size() == 1) { + // Unary concatenates are useless. ReplaceInstructionIfSameShape(concatenate, operands[0]); + return Status::OK(); + } + // Filter out and remove empty operands. 
+ std::vector nonempty_operands; + for (HloInstruction* operand : operands) { + if (!ShapeUtil::HasZeroElements(operand->shape())) { + nonempty_operands.push_back(operand); + } + } + if (nonempty_operands.size() < operands.size()) { + HloInstruction* replacement; + if (nonempty_operands.empty()) { + replacement = operands[0]; + } else if (nonempty_operands.size() == 1) { + replacement = nonempty_operands[0]; + } else { + replacement = + computation_->AddInstruction(concatenate->CloneWithNewOperands( + concatenate->shape(), nonempty_operands)); + } + VLOG(10) << "trying to replace " << concatenate->ToString() << " with " + << replacement->ToString(); + ReplaceInstructionIfSameShape(concatenate, replacement); } return Status::OK(); } @@ -316,7 +378,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub, HloInstruction* rhs) { // A - 0 => A VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); - if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { return Status::OK(); } @@ -328,8 +390,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, HloInstruction* rhs) { // A/1 => A VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); - if (IsLiteralWithValue(rhs, 1) && - ReplaceInstructionIfSameShape(divide, lhs)) { + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { return Status::OK(); } @@ -340,8 +401,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, computation_->AddInstruction(HloInstruction::CreateBinary( divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0), rhs->mutable_operand(0))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, subtract)); } @@ -368,8 +428,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, 
ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -378,8 +437,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, rhs->mutable_operand(0), lhs->mutable_operand(0))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -387,8 +445,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, // // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, lhs, rhs)); } @@ -412,8 +469,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, {0}, add_reduce_computation)); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateReshape(dot->shape(), reduce)); } @@ -452,8 +508,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, {rhs->shape().dimensions(1)}), multiply, zero, {0}, add_reduce_computation)); } - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateReshape(dot->shape(), reduce)); } @@ -479,8 +534,7 @@ Status 
AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, ShapeUtil::MakeShape(dot->shape().element_type(), {lhs->shape().dimensions(0)}), multiply, zero, {1}, add_reduce_computation)); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateReshape(dot->shape(), reduce)); } return Status::OK(); @@ -491,14 +545,12 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, HloInstruction* rhs) { // A*1 => A VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); - if (IsLiteralWithValue(rhs, 1) && - ReplaceInstructionIfSameShape(multiply, lhs)) { + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { return Status::OK(); } // 1*A => A VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString(); - if (IsLiteralWithValue(lhs, 1) && - ReplaceInstructionIfSameShape(multiply, rhs)) { + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { return Status::OK(); } return Status::OK(); @@ -619,8 +671,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " "n(broadcast(X)) == n(X)"; - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand)); } @@ -632,8 +683,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " "n(broadcast(X)) == n(X)"; - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, broadcast->dimensions())); } @@ -653,8 +703,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { for (auto 
inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( broadcast, HloInstruction::CreateBroadcast(broadcast->shape(), operand->mutable_operand(0), dims)); @@ -697,65 +746,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -template -static std::unique_ptr ConvertIfTypesMatch( - const Literal& src_literal) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - - return HloInstruction::CreateConstant( - LiteralUtil::Convert::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal)); -} - -template -static std::unique_ptr ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. 
- default: - LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -static std::unique_ptr ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - // A conversion to the same element type as the operand is a nop and can be // removed. A conversion of a constant can be simplified by making a new // constant. 
@@ -764,16 +754,7 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert, PrimitiveType src_type = operand->shape().element_type(); PrimitiveType dest_type = convert->shape().element_type(); if (src_type == dest_type) { - changed_ = true; - return computation_->ReplaceInstruction(convert, operand); - } - if (operand->opcode() == HloOpcode::kConstant) { - const Literal& src_literal = operand->literal(); - std::unique_ptr new_constant = - ConvertIfSrcTypeMatches(src_literal, dest_type); - changed_ = true; - return computation_->ReplaceWithNewInstruction(convert, - std::move(new_constant)); + return ReplaceInstruction(convert, operand); } return Status::OK(); } @@ -859,8 +840,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { std::unique_ptr slice = HloInstruction::CreateSlice( pad->shape(), nonzero_pad, start_indices, end_indices); - changed_ = true; - return computation_->ReplaceWithNewInstruction(pad, std::move(slice)); + return ReplaceWithNewInstruction(pad, std::move(slice)); } return Status::OK(); @@ -870,7 +850,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 0)) { + if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( LiteralUtil::One(power->shape().element_type()))); std::unique_ptr ones; @@ -880,30 +860,27 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, ones = HloInstruction::CreateBroadcast( power->shape(), computation_->AddInstruction(std::move(one)), {}); } - changed_ = true; - return computation_->ReplaceWithNewInstruction(power, std::move(ones)); + return ReplaceWithNewInstruction(power, std::move(ones)); } VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { + if (IsAll(rhs, 1) && 
ReplaceInstructionIfSameShape(power, lhs)) { return Status::OK(); } VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 2)) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + if (IsAll(rhs, 2)) { + return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kMultiply, lhs, lhs)); } VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, -1)) { + if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( LiteralUtil::One(rhs->shape().element_type())))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, one, lhs)); } @@ -984,14 +961,12 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // Delete no-op reshapes, i.e. where shape = operand shape. if (SameShape(reshape, operand)) { VLOG(10) << "deleting no-op reshape"; - changed_ = true; - return computation_->ReplaceInstruction(reshape, operand); + return ReplaceInstruction(reshape, operand); } // Merge reshapes. 
if (HloOpcode::kReshape == operand->opcode()) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } @@ -1000,8 +975,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( reshape, reshape->operand(0)->dimensions()); if (opt_dims.first) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reshape, HloInstruction::CreateBroadcast( reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), @@ -1037,8 +1011,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse, }; if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), dim_is_one)) { - changed_ = true; - return computation_->ReplaceInstruction(reverse, operand); + return ReplaceInstruction(reverse, operand); } return Status::OK(); } @@ -1052,12 +1025,22 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice, return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice, HloInstruction* operand, + HloInstruction* update, HloInstruction* start_indices) { + // DynamicUpdateSlice on a scalar just passes through the update argument. 
+ if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { + return ReplaceInstruction(dynamic_update_slice, update); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { if (ShapeUtil::HasZeroElements(arg->shape()) || ShapeUtil::HasZeroElements(reduce->shape())) { - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); return Status::OK(); @@ -1070,7 +1053,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce( for (auto dim : dimensions) { new_reduce_dimensions.push_back(transpose_dimensions[dim]); } - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( reduce->shape(), arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); @@ -1114,7 +1097,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce( new_reduce_dimensions.push_back(i); } } - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( reduce->shape(), arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); @@ -1125,27 +1108,84 @@ Status AlgebraicSimplifierVisitor::HandleReduce( ShapeUtil::HasZeroElements(arg->shape())) { auto reshape = computation_->AddInstruction( HloInstruction::CreateReshape(reduce->shape(), arg)); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateMap(reduce->shape(), {reshape, init_value}, function)); } return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleReduceWindow( + HloInstruction* reduce_window, HloInstruction* operand, + const Window& window, HloComputation* function) { + VLOG(10) << "Considering folding Pad: " << operand->ToString() + << "\ninto 
reduce-window: " << reduce_window->ToString(); + + // This optimization folds a pad op into reduce_window. + if (operand->opcode() != HloOpcode::kPad) { + VLOG(10) << "Not folding pad into reduce-window as there is no pad."; + return Status::OK(); + } + + // Do not fold interior padding into ReduceWindow since the backends do not + // support it. + const PaddingConfig& pad_config = operand->padding_config(); + if (HasInteriorPadding(pad_config)) { + VLOG(10) << "Not folding pad into reduce-window due to interior padding."; + return Status::OK(); + } + + // If reduce_window already has padding, the pad value of the pad op and the + // init value of reduce_window must match to allow folding the pad. + const HloInstruction* pad_value = operand->operand(1); + const HloInstruction* reduce_init_value = reduce_window->operand(1); + if (pad_value != reduce_init_value) { + // The pad value is usually a constant, so we handle that case and do not + // try to get more fancy about proving equivalence in cases beyond that. + if (pad_value->opcode() != HloOpcode::kConstant || + reduce_init_value->opcode() != HloOpcode::kConstant || + !LiteralUtil::Equal(pad_value->literal(), + reduce_init_value->literal())) { + VLOG(10) + << "Not folding pad into reduce-window due to different pad values."; + return Status::OK(); + } + } + + // Carry out the folding of the pad into reduce_window. 
+ VLOG(10) << "Folding pad into reduce-window."; + Window new_window = window; + const int64 rank = ShapeUtil::Rank(reduce_window->shape()); + TF_RET_CHECK(pad_config.dimensions_size() == rank); + TF_RET_CHECK(window.dimensions_size() == rank); + for (int64 i = 0; i < rank; ++i) { + const auto& pad_dim = pad_config.dimensions(i); + auto& window_dim = *new_window.mutable_dimensions(i); + window_dim.set_padding_low(window_dim.padding_low() + + pad_dim.edge_padding_low()); + window_dim.set_padding_high(window_dim.padding_high() + + pad_dim.edge_padding_high()); + } + return ReplaceWithNewInstruction( + reduce_window, HloInstruction::CreateReduceWindow( + /*shape=*/reduce_window->shape(), + /*operand=*/operand->mutable_operand(0), + /*init_value=*/reduce_window->mutable_operand(1), + /*window=*/new_window, + /*reduce_computation=*/function)); +} + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), transpose->dimensions().end())) { VLOG(10) << "deleting no-op transpose"; - changed_ = true; - return computation_->ReplaceInstruction(transpose, operand); + return ReplaceInstruction(transpose, operand); } if (HloOpcode::kTranspose == operand->opcode()) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( transpose, HloInstruction::CreateTranspose( transpose->shape(), operand->mutable_operand(0), ComposePermutations(operand->dimensions(), @@ -1272,9 +1312,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_rhs = add_bitcast(new_filter_shape, rhs); auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); - changed_ = true; - return computation_->ReplaceInstruction(convolution, - add_bitcast(convolution_shape, dot)); + return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } bool 
AlgebraicSimplifierVisitor::TransformToClampIfSameShape( @@ -1288,8 +1326,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp, max_operand, operand, min_operand); - TF_CHECK_OK(computation_->ReplaceWithNewInstruction(root, std::move(clamp))); - changed_ = true; + TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp))); return true; } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 3123ee4f8728a8d76b16bc4b3162962757d3b778..87d8a7165ccfad587474a0c89e9387597e341d8f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -41,6 +41,7 @@ namespace { AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; } + AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } @@ -69,6 +70,52 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); + builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); + HloInstruction* bcast = + builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); + builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -420,115 +467,108 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { EXPECT_THAT(computation->root_instruction(), input); } -TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) { - HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); - - auto module = MakeUnique(TestName()); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - - 
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), - 42); -} - -TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) { +// Test that copies are removed. +TEST_F(AlgebraicSimplifierTest, RemoveCopy) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); + HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), - 42.0f); + EXPECT_THAT(computation->root_instruction(), param0); } -TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) { +// Test that unary concatenates are removed. 
+TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { + Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({42.0f, 19.0f}))); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); + HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {0}), - 42); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {1}), - 19); + EXPECT_THAT(computation->root_instruction(), param0); } -// Test that copies are removed. -TEST_F(AlgebraicSimplifierTest, RemoveCopy) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); +// Test that empty operands of concatenates are removed. 
+TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { + const int kParamLength = 100; + Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - builder.AddInstruction( - HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r1f32, "param1")); + HloInstruction* empty_literal = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction* empty_slice = + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42})); + Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength}); + builder.AddInstruction(HloInstruction::CreateConcatenate( + result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT( + computation->root_instruction(), + op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), param0); + EXPECT_THAT(computation->root_instruction(), + op::Concatenate(param0, param0, param1)); } -// Test that unary concatenates are removed. -TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { - Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); +// Test a concatenate with only empty operands is removed. 
+TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { + const int kParamLength = 100; + Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); - builder.AddInstruction( - HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); + HloInstruction* empty_literal = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction* empty_slice = + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42})); + Shape result_shape = ShapeUtil::MakeShape(F32, {0}); + builder.AddInstruction(HloInstruction::CreateConcatenate( + result_shape, {empty_literal, empty_slice}, 0)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); + EXPECT_THAT(computation->root_instruction(), + op::Concatenate(empty_literal, empty_slice)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), param0); + EXPECT_EQ(computation->root_instruction(), empty_literal); } // Test that a simplification which changes layouts is not performed if layout @@ -1508,6 +1548,86 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); } +// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). +TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Create operand to the pad. 
+ HloInstruction* operand = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0")); + + // Create the pad. + PaddingConfig padding = MakeNoPaddingConfig(4); + padding.mutable_dimensions(1)->set_edge_padding_low(1); + padding.mutable_dimensions(3)->set_edge_padding_high(2); + + HloInstruction* pad_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); + + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module.AddEmbeddedComputation(builder.Build()); + } + + // Create the reduce-window. + Window window; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + auto* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(10); + dim->set_padding_high(100); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + const Shape reduce_window_shape = + ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); + HloInstruction* reduce_init_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + reduce_window_shape, pad, reduce_init_value, window, + add_computation)); + + // Build the computation and run the simplifier. 
+ auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, reduce_window); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + + // Running simplification again should not result in any further changes. + ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + + // Verify the result + root = computation->root_instruction(); + EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant())); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) + << ShapeUtil::HumanString(root->shape()) << " vs " + << ShapeUtil::HumanString(reduce_window_shape); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(1).padding_low(), 11); + EXPECT_EQ(root->window().dimensions(2).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(3).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(1).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(2).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); +} + TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { HloComputation::Builder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1}); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 49a621810ef26f76494bab08d087bb4a07472000..83759a7a0c62222b81b82b8a0f8e0396a8f17eff 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -64,8 +64,9 @@ GlobalDataHandle AllocationTracker::RegisterInternal( auto& allocation = FindOrDie(handle_to_allocation_, handle); int ref_count = allocation->ref_count(); CHECK_GT(ref_count, 0); - VLOG(2) << "ref_count: " << ref_count 
<< " -> " << ref_count + 1; - allocation->increment_ref_count(); + VLOG(2) << "ref_count: " << ref_count << " -> " << + (ref_count + initial_ref_count); + allocation->increment_ref_count(initial_ref_count); } else { handle = next_handle_++; VLOG(2) << "ref_count: " << initial_ref_count; diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index e00768001620275d702c2f96a89d981526ea81a7..ebbf35b6fe87bc7322ccb99cfe8f8eed56de06b3 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -63,10 +63,10 @@ class Allocation { CHECK_GE(ref_count_, 0); return ref_count_; } - void increment_ref_count() { + void increment_ref_count(int inc) { CHECK_GT(ref_count_, 0); - CHECK_LT(ref_count_, INT_MAX); - ++ref_count_; + CHECK_LE(ref_count_, INT_MAX - inc); + ref_count_ += inc; } void decrement_ref_count() { CHECK_GT(ref_count_, 0); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 5c05417c6dcb887b5352d1270c24a4eae62149e3..1913617fecf757a529bbdc803b4227a560c6e1cf 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -41,13 +41,39 @@ namespace se = ::perftools::gputools; namespace xla { +BackendOptions& BackendOptions::set_platform( + perftools::gputools::Platform* platform) { + platform_ = platform; + return *this; +} + +perftools::gputools::Platform* BackendOptions::platform() const { + return platform_; +} + +BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) { + number_of_replicas_ = number_of_replicas; + return *this; +} + +int BackendOptions::number_of_replicas() const { return number_of_replicas_; } + +BackendOptions& BackendOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int 
BackendOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct Backend::EigenThreadPoolWrapper { - explicit EigenThreadPoolWrapper() - : pool(new tensorflow::thread::ThreadPool( - tensorflow::Env::Default(), "XLAEigen", - tensorflow::port::NumSchedulableCPUs())), + explicit EigenThreadPoolWrapper(const int num_threads) + : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(), + "XLAEigen", num_threads)), wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), device(new Eigen::ThreadPoolDevice(wrapper.get(), wrapper->NumThreads())) {} @@ -58,18 +84,21 @@ struct Backend::EigenThreadPoolWrapper { }; /* static */ StatusOr> Backend::CreateBackend( - perftools::gputools::Platform* platform, int64 replica_count) { + const BackendOptions& options) { + int64 replica_count = options.number_of_replicas(); if (replica_count == -1) { legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags(); replica_count = flags->xla_replicas; } + perftools::gputools::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto stream_executors, PlatformUtil::GetStreamExecutors(platform)); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); - std::unique_ptr backend(new Backend( - replica_count, platform, compiler, stream_executors, transfer_manager)); + std::unique_ptr backend( + new Backend(replica_count, platform, compiler, stream_executors, + transfer_manager, options.intra_op_parallelism_threads())); TF_RETURN_IF_ERROR(backend->PoolStreams(kInitialStreamsToPool, backend->default_stream_executor())); return std::move(backend); @@ -79,7 +108,9 @@ struct Backend::EigenThreadPoolWrapper { Backend::CreateDefaultBackend() { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetDefaultPlatform()); 
- return CreateBackend(platform); + BackendOptions backend_options; + backend_options.set_platform(platform); + return CreateBackend(backend_options); } tensorflow::Status Backend::PoolStreams(int n, se::StreamExecutor* executor) { @@ -114,7 +145,7 @@ Backend::Backend( int64 replica_count, perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager) + TransferManager* transfer_manager, int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), @@ -144,7 +175,11 @@ Backend::Backend( inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( tensorflow::Env::Default(), "xla_inter_op", tensorflow::port::NumSchedulableCPUs())); - intra_op_thread_pool_wrapper_.reset(new EigenThreadPoolWrapper()); + const int num_threads = intra_op_parallelism_threads > 0 + ? intra_op_parallelism_threads + : tensorflow::port::NumSchedulableCPUs(); + intra_op_thread_pool_wrapper_.reset( + new EigenThreadPoolWrapper(num_threads)); } } @@ -190,10 +225,17 @@ tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { - if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr; + if (intra_op_thread_pool_wrapper_ == nullptr) { + return nullptr; + } return intra_op_thread_pool_wrapper_->device.get(); } +tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { + if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr; + return intra_op_thread_pool_wrapper_->pool.get(); +} + StatusOr Backend::stream_executor( int device_ordinal) const { if (device_ordinal < 0 || diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 9f6829b7d937cec6a67d4016a40506de5df8572d..1068bac2779e9a3dc6c23c0b9fbcc5403fcc2815 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ 
b/tensorflow/compiler/xla/service/backend.h @@ -39,6 +39,31 @@ struct ThreadPoolDevice; namespace xla { +// Options to configure the backend when it is created. +class BackendOptions { + public: + // Set the platform backing the backend, or nullptr for the default platform. + BackendOptions& set_platform(perftools::gputools::Platform* platform); + perftools::gputools::Platform* platform() const; + + // Set the number of replicas to use when compiling replicated + // programs. The default is -1 meaning that the value is read from + // the xla_replicas flag. + BackendOptions& set_number_of_replicas(int number_of_replicas); + int number_of_replicas() const; + + // Sets the thread pool size for parallel execution of an individual operator. + // The default value of -1 will result in initializing the thread pool with + // the number of threads equal to the number of cores in the system. + BackendOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + + private: + perftools::gputools::Platform* platform_ = nullptr; + int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; +}; + // Class which encapsulates an XLA backend. It includes everything necessary // to compile and execute computations on a particular platform. // @@ -53,9 +78,9 @@ class Backend { static constexpr int kInitialStreamsToPool = 8; // Creates a new backend for the given platform with the given number of - // replicas. A value of -1 means to use the flag value. + // replicas. static StatusOr> CreateBackend( - perftools::gputools::Platform* platform, int64 replica_count = -1); + const BackendOptions& options); // Creates a backend for the default platform. The default platform is defined // in PlatformUtil. @@ -150,6 +175,7 @@ class Backend { // For the host platform, returns the configured eigen threadpool device to be // used for scheduling work. For other platforms, returns NULL. 
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; + tensorflow::thread::ThreadPool* eigen_intra_op_thread_pool() const; // Resets the devices associated with this backend. Status ResetDevices(); @@ -160,7 +186,7 @@ class Backend { Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager); + TransferManager* transfer_manager, int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 6efa73b1211da9d41c502818a0bc570fa7773fc6..47560fefea855fa7f70ef6268252a9b6d9964f76 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -488,11 +488,9 @@ Status GatherComputationsByAllocationType( /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, int64 alignment, - const std::vector* hlos_to_allocate) { + LogicalBuffer::SizeFunction buffer_size, int64 alignment) { BufferAssigner assigner(std::move(buffer_size), alignment); - return assigner.CreateAssignment(module, std::move(hlo_ordering), - hlos_to_allocate); + return assigner.CreateAssignment(module, std::move(hlo_ordering)); } bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, @@ -545,24 +543,22 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const FlatSet* hlos_to_allocate, const FlatSet& colocated_buffers, const FlatSet& colocated_allocations, + FlatMap>* + buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of // size. 
std::vector sorted_buffers; for (auto& instruction : computation->instructions()) { - if (hlos_to_allocate == nullptr || - hlos_to_allocate->count(instruction.get()) > 0) { - // Add all buffers which this instruction defines. Instruction which don't - // define buffers (eg, bitcast which just forwards a pointer) don't need - // any allocations. - for (const LogicalBuffer* buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - instruction.get())) { - sorted_buffers.push_back(buffer); - } + // Add all buffers which this instruction defines. Instructions which don't + // define buffers (eg, bitcast which just forwards a pointer) don't need + // any allocations. + for (const LogicalBuffer* buffer : + assignment->points_to_analysis().GetBuffersDefinedByInstruction( + instruction.get())) { + sorted_buffers.push_back(buffer); } } @@ -578,9 +574,16 @@ Status BufferAssigner::AssignBuffersForComputation( // If there is a sequential instruction ordering, we'll delay assignment of // temp buffers until after the main assignment loop. const BufferLiveness& liveness = assignment->liveness(); - const std::vector* sequential_order = - liveness.hlo_ordering().SequentialOrder(*computation); - FlatSet unassigned_temp_buffers; + const bool has_sequential_order = + liveness.hlo_ordering().SequentialOrder(*computation) != nullptr; + if (has_sequential_order && buffers_to_assign_sequentially != nullptr) { + // Every sequential computation must get an entry in the + // buffers_to_assign_sequentially map, even if we end up with an empty set + // of buffers. This ensures we can correctly determine whether to run + // whole-module heap simulation. + buffers_to_assign_sequentially->emplace(computation, + FlatSet()); + } // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers // first for simplicity. 
This means any previously created BufferAllocation is @@ -599,7 +602,7 @@ Status BufferAssigner::AssignBuffersForComputation( // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [this, sequential_order, &liveness, &post_order_position]( + [this, has_sequential_order, &liveness, &post_order_position]( const LogicalBuffer* a, const LogicalBuffer* b) { // Primary sort is by decreasing buffer size. const int64 a_size = buffer_size_(*a); @@ -609,7 +612,7 @@ Status BufferAssigner::AssignBuffersForComputation( } // Otherwise live out buffers come before others, if the // instructions are sequentially ordered. - if (sequential_order != nullptr) { + if (has_sequential_order) { const bool a_live_out = liveness.MaybeLiveOut(*a); const bool b_live_out = liveness.MaybeLiveOut(*b); if (a_live_out != b_live_out) { @@ -746,7 +749,7 @@ Status BufferAssigner::AssignBuffersForComputation( } } - if (!assignment->HasAllocation(*buffer) && sequential_order != nullptr && + if (!assignment->HasAllocation(*buffer) && has_sequential_order && !liveness.MaybeLiveOut(*buffer)) { // There is a sequential instruction ordering, so we delay assignment of // temp buffers until after the loop. We do this right before we decide to @@ -758,7 +761,7 @@ Status BufferAssigner::AssignBuffersForComputation( // for the definition of temp buffers. 
CHECK(!is_entry_parameter) << *buffer; CHECK(!is_thread_local) << *buffer; - unassigned_temp_buffers.insert(buffer); + (*buffers_to_assign_sequentially)[computation].insert(buffer); VLOG(3) << "Delaying assignment of temp buffer: " << *buffer; continue; } @@ -772,27 +775,68 @@ Status BufferAssigner::AssignBuffersForComputation( } } - if (!unassigned_temp_buffers.empty()) { - TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( - *sequential_order, unassigned_temp_buffers, *computation, assignment)); - } return Status::OK(); } Status BufferAssigner::AssignBuffersWithSequentialOrdering( - const std::vector& sequence, - const FlatSet& buffers_to_assign, - const HloComputation& computation, BufferAssignment* assignment) { + const FlatMap>& + buffers_to_assign_sequentially, + bool run_whole_module_heap_simulation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic // that seems to give the best results is lazy-best-fit, with all runs of // alloc / free calls sorted in decreasing size order. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment_)), - sequence, computation, - assignment->points_to_analysis(), buffer_size_, - &buffers_to_assign)); + const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); + if (run_whole_module_heap_simulation) { + // Run the heap simulation over the whole module. This reduces memory usage, + // since buffers for kCall and kWhile sub-computations are only live for the + // duration of their calling instructions. 
+ VLOG(1) << "Running whole-module heap simulation"; + SequentialHloOrdering::HloModuleSequence module_sequence; + FlatSet all_buffers_to_assign; + for (const auto& pair : buffers_to_assign_sequentially) { + const HloComputation* computation = pair.first; + const FlatSet& buffers_to_assign = pair.second; + const std::vector* instruction_sequence = + hlo_ordering.SequentialOrder(*computation); + CHECK(instruction_sequence != nullptr) << computation->name(); + module_sequence[computation] = *instruction_sequence; + all_buffers_to_assign.insert(buffers_to_assign.begin(), + buffers_to_assign.end()); + } + TF_ASSIGN_OR_RETURN( + const HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique( + MakeUnique(alignment_)), + assignment->module(), module_sequence, + assignment->points_to_analysis(), buffer_size_, + &all_buffers_to_assign)); + AssignBuffersFromHeapSimulator(result, assignment); + } else { + // Run the heap-simulation on a per-computation basis. Buffers for + // sub-computations are assigned disjoint BufferAllocations, assuming the + // worst-case that they may all be live concurrently. 
+ VLOG(1) << "Running per-computation heap simulation"; + for (const auto& pair : buffers_to_assign_sequentially) { + const HloComputation* computation = pair.first; + const FlatSet& buffers_to_assign = pair.second; + const std::vector* instruction_sequence = + hlo_ordering.SequentialOrder(*computation); + CHECK(instruction_sequence != nullptr) << computation->name(); + TF_ASSIGN_OR_RETURN( + const HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique( + MakeUnique(alignment_)), + *computation, *instruction_sequence, + assignment->points_to_analysis(), buffer_size_, + &buffers_to_assign)); + AssignBuffersFromHeapSimulator(result, assignment); + } + } + return Status::OK(); +} + +void BufferAssigner::AssignBuffersFromHeapSimulator( + const HeapSimulator::Result& result, BufferAssignment* assignment) { if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) { assignment->stats_.preallocated_temp_fragmentation_bytes = result.fragmentation_size; @@ -801,8 +845,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( result.fragmentation_size; } - // Use the results of the heap simulator to create one allocation per - // computation, with LogicalBuffers packed to specific offsets. 
BufferAllocation* allocation = assignment->NewEmptyAllocation( result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true); for (const auto& buffer_chunk : result.chunk_map) { @@ -810,7 +852,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HeapSimulator::Chunk& chunk = buffer_chunk.second; assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } - return Status::OK(); } // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining @@ -1103,35 +1144,15 @@ void BufferAssigner::AssignColocatedBufferSets( } StatusOr> BufferAssigner::CreateAssignment( - const HloModule* module, std::unique_ptr hlo_ordering, - const std::vector* hlos_to_allocate) { + const HloModule* module, std::unique_ptr hlo_ordering) { TF_ASSIGN_OR_RETURN(std::unique_ptr liveness, BufferLiveness::Run(module, std::move(hlo_ordering))); - std::vector thread_local_computations; - std::vector global_computations; VLOG(1) << "Assigning buffers to module " << module->name(); - if (hlos_to_allocate != nullptr) { - VLOG(3) << "LogicalBuffer assignment restricted to hlos: "; - for (auto hlo : *hlos_to_allocate) { - VLOG(3) << " " << hlo->parent()->name() << "::" << hlo->name(); - } - } - XLA_VLOG_LINES(3, module->ToString()); + XLA_VLOG_LINES(2, module->ToString()); XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( - module, &thread_local_computations, &global_computations)); - - // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to - // AssignBuffersForComputation for fast membership testing. - std::unique_ptr> hlo_set; - if (hlos_to_allocate != nullptr) { - hlo_set = MakeUnique>( - hlos_to_allocate->begin(), hlos_to_allocate->end()); - } - // Can't use MakeUnique because BufferAssignment constructor is private. 
std::unique_ptr assignment( new BufferAssignment(module, std::move(liveness), alignment_)); @@ -1148,16 +1169,38 @@ StatusOr> BufferAssigner::CreateAssignment( AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(), &colocated_buffers, &colocated_allocations); + std::vector thread_local_computations; + std::vector global_computations; + TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( + module, &thread_local_computations, &global_computations)); + + // First assign buffers for global computations. Temporary buffers for + // sequential computations are collected in 'buffers_to_assign_sequentially'. + FlatMap> + buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, /*is_thread_local=*/false, hlo_set.get(), - colocated_buffers, colocated_allocations, assignment.get())); + computation, /*is_thread_local=*/false, colocated_buffers, + colocated_allocations, &buffers_to_assign_sequentially, + assignment.get())); } + // Assign buffers with sequential ordering, if any. If all global computations + // are sequential, we can run heap simulation on the whole module, which + // reduces memory usage. + const bool run_whole_module_heap_simulation = + buffers_to_assign_sequentially.size() == global_computations.size(); + TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( + buffers_to_assign_sequentially, run_whole_module_heap_simulation, + assignment.get())); + + // Now assign buffers for thread-local computations. All LogicalBuffers get + // their own BufferAllocation. 
for (auto* computation : thread_local_computations) { TF_RET_CHECK(computation != module->entry_computation()); TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, /*is_thread_local=*/true, hlo_set.get(), colocated_buffers, - colocated_allocations, assignment.get())); + computation, /*is_thread_local=*/true, colocated_buffers, + colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, + assignment.get())); } // Mark all buffers which may be live out of the entry computation as diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 34667c435d5448ab2a518733516e4a5140fb3dc4..9774a3174acfc7dcf219532a3c0eae22ad5f743c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -354,6 +355,9 @@ class BufferAssignment { void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, int64 offset, int64 size); + // Returns the HloModule used to construct this assignment. + const HloModule& module() { return *module_; } + // Returns the BufferLiveness object used to construct this assignment. const BufferLiveness& liveness() { return *liveness_; } @@ -396,13 +400,10 @@ class BufferAssigner { // Build and return a BufferAssignment for the given module. The given // HloOrdering is used to determine buffer liveness. buffer_size is a function // which returns the size of a LogicalBuffer. Alignment is the the minimum - // alignment of any buffer. If hlos_to_allocate is not null then only - // instructions in this vector are considered for buffer assignment. 
If - // hlos_to_allocate is null then all instructions are considered. + // alignment of any buffer. static StatusOr> Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, int64 alignment, - const std::vector* hlos_to_allocate = nullptr); + LogicalBuffer::SizeFunction buffer_size, int64 alignment); private: explicit BufferAssigner(LogicalBuffer::SizeFunction buffer_size, @@ -412,29 +413,38 @@ class BufferAssigner { // Create a buffer assignment. StatusOr> CreateAssignment( - const HloModule* module, std::unique_ptr hlo_ordering, - const std::vector* hlos_to_allocate = nullptr); + const HloModule* module, std::unique_ptr hlo_ordering); // Assigns buffers to the instructions in the given computation. "assignment" // is modified to reflect the new buffer assignments. If is_thread_local is // true, then all assigned buffers have the is_thread_local flag set to - // true. If hlos_to_allocate is not null it indicates which HLOs to include in - // buffer assignment. If null, all instructions in the computation are - // included. + // true. Status AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet* hlos_to_allocate, const tensorflow::gtl::FlatSet& colocated_buffers, const tensorflow::gtl::FlatSet& colocated_allocations, + tensorflow::gtl::FlatMap>* + buffers_to_assign_sequentially, BufferAssignment* assignment); - // Assigns 'buffers_to_assign' assuming the HLO instructions will be executed - // in the given 'sequential_order'. + // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming + // the HLO instructions will be executed in the sequential order given by + // assignment->liveness().hlo_ordering().SequentialOrder. If + // 'run_whole_module_heap_simulation' is true, the heap simulation will be run + // assuming all global computations are sequentially ordered. 
Status AssignBuffersWithSequentialOrdering( - const std::vector& sequential_order, - const tensorflow::gtl::FlatSet& buffers_to_assign, - const HloComputation& computation, BufferAssignment* assignment); + const tensorflow::gtl::FlatMap< + const HloComputation*, + tensorflow::gtl::FlatSet>& + buffers_to_assign_sequentially, + bool run_whole_module_heap_simulation, BufferAssignment* assignment); + + // Uses the results of the heap simulator to create a single allocation, with + // LogicalBuffers packed to specific offsets. + void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result, + BufferAssignment* assignment); // Tries to assign the given instruction to the given buffer. Returns if the // assignment was successful. @@ -477,8 +487,6 @@ class BufferAssigner { const HloComputation& computation, const BufferLiveness& buffer_liveness, std::vector* colocated_buffer_sets); - const HloModule* module_; - // Function which returns the buffer size for a given logical buffer (shape). LogicalBuffer::SizeFunction buffer_size_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 0d6e89c5c6a4fbe2c7ed1acabcd743939faedc3a..ac1d769010c55ee4430554abe3205391bee5ebf1 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -856,8 +856,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { EXPECT_FALSE(map_root_alloc.maybe_live_out()); EXPECT_TRUE(map_root_alloc.is_thread_local()); - // Allocations for the call computation should not be thread-local and not - // live-out. + // Allocations for the call computation should not be thread-local. 
auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param); EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter()); EXPECT_FALSE(call_param_alloc.maybe_live_out()); @@ -865,7 +864,6 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root); EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter()); - EXPECT_FALSE(call_root_alloc.maybe_live_out()); EXPECT_FALSE(call_root_alloc.is_thread_local()); // Entry computation allocations can be marked liveout and @@ -1445,8 +1443,7 @@ TEST_F(BufferAssignmentTest, TwoCalls) { FlattenCallGraph flatten; TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(module.get())); + std::unique_ptr call_graph = CallGraph::Build(module.get()); } RunCopyInsertion(module.get()); diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 38c2c8155186877355920042c63b52bf7192c1f6..3be4810490561808df2b34e341cfbd04928f8585 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -45,9 +45,7 @@ StatusOr> BufferLiveness::Run( } tensorflow::Status BufferLiveness::Analyze() { - TF_ASSIGN_OR_RETURN(points_to_analysis_, - TuplePointsToAnalysis::Run( - module_, /*include_loop_fusion_instructions=*/true)); + TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto& computation : module_->computations()) { // Gather all instructions whose buffers might alias other instructions into // the set aliased_buffers_. 
This includes those contained as a tuple @@ -117,11 +115,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // If 'b' is a user of 'a' then the buffers interfere unless 'a.instruction' // and 'b.instruction' emit the same shape/layout, and 'b.instruction' meets - // one of following qualifications: - // *) Is element-wise. - // *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where - // the singleton use of 'a' at 'a.index' is the fused root at operand 0. - // *) Use of 'operand' is DynamicUpdateSlice at operand index 0. + // the qualifications specified in CanShareOperandBufferWithUser. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { if (b.instruction()->IsUserOf(alias.instruction()) && !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 57d69f5b71b336aba5bb9a8105b66ae5a5baa50a..fa7b2a309525dd80d655e10474c5d49f9da14ea8 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -98,12 +98,12 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) { } } -Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { - TF_RET_CHECK(instruction->parent() == computation()); +void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { + CHECK_EQ(instruction->parent(), computation()); const CallContext context = GetInstructionCallContext(instruction); if (!instruction->called_computations().empty()) { - TF_RET_CHECK(context == CallContext::kSequential || - context == CallContext::kParallel); + CHECK(context == CallContext::kSequential || + context == CallContext::kParallel); callsite_instructions_.insert({instruction, callsites_.size()}); callsites_.push_back( CallSite(instruction, instruction->called_computations(), context)); @@ -116,22 +116,21 @@ Status 
CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { } } } - return Status::OK(); } CallGraph::CallGraph(const HloModule* module) : module_(module) {} -StatusOr CallGraph::GetNode( +const CallGraphNode& CallGraph::GetNode( const HloComputation* computation) const { auto it = node_indices_.find(computation); - TF_RET_CHECK(it != node_indices_.end()); - return &nodes_[it->second]; + CHECK(it != node_indices_.end()); + return nodes_[it->second]; } -StatusOr CallGraph::GetNode(const HloComputation* computation) { +CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { auto it = node_indices_.find(computation); - TF_RET_CHECK(it != node_indices_.end()); - return &nodes_[it->second]; + CHECK(it != node_indices_.end()); + return nodes_[it->second]; } namespace { @@ -154,17 +153,17 @@ CallContext UnionContexts(CallContext a, CallContext b) { } // namespace -Status CallGraph::SetCallContexts() { +void CallGraph::SetCallContexts() { std::queue worklist; // Initialize worklist with all roots of the call graph (computations without // callers). for (const std::unique_ptr& computation : module_->computations()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get())); - if (node->callers().empty()) { - node->set_context(CallContext::kSequential); - worklist.push(node); + CallGraphNode& node = GetNode(computation.get()); + if (node.callers().empty()) { + node.set_context(CallContext::kSequential); + worklist.push(&node); } } @@ -174,7 +173,7 @@ Status CallGraph::SetCallContexts() { for (const CallSite& callsite : node->callsites()) { for (const HloComputation* callee : callsite.called_computations()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, GetNode(callee)); + CallGraphNode& callee_node = GetNode(callee); // Update context of callee computation based on the callsite and its // current context. 
@@ -182,16 +181,16 @@ Status CallGraph::SetCallContexts() { if (callsite.context() == CallContext::kParallel) { context_to_add = CallContext::kParallel; } else { - TF_RET_CHECK(callsite.context() == CallContext::kSequential); + CHECK_EQ(callsite.context(), CallContext::kSequential); context_to_add = node->context(); } CallContext new_context = - UnionContexts(context_to_add, callee_node->context()); + UnionContexts(context_to_add, callee_node.context()); - if (new_context != callee_node->context()) { + if (new_context != callee_node.context()) { // Context of computation has been changed so add node to worklist. - callee_node->set_context(new_context); - worklist.push(callee_node); + callee_node.set_context(new_context); + worklist.push(&callee_node); } } } @@ -200,14 +199,12 @@ Status CallGraph::SetCallContexts() { // No node should have a kNone calling context. for (const std::unique_ptr& computation : module_->computations()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get())); - TF_RET_CHECK(node->context() != CallContext::kNone); + CHECK_NE(GetNode(computation.get()).context(), CallContext::kNone); } - return Status::OK(); } /* static */ -StatusOr> CallGraph::Build(const HloModule* module) { +std::unique_ptr CallGraph::Build(const HloModule* module) { // Constructor for CallGraph is private so MakeUnique can't be used. auto call_graph = WrapUnique(new CallGraph(module)); @@ -221,54 +218,49 @@ StatusOr> CallGraph::Build(const HloModule* module) { {computation.get(), call_graph->nodes_.size()}); // All computations should be unique, so the computation should not already // exist in the map. - TF_RET_CHECK(it_added.second); + CHECK(it_added.second); call_graph->nodes_.emplace_back(computation.get()); // Add all callsites in this computation. 
for (const std::unique_ptr& instruction : computation->instructions()) { - TF_RETURN_IF_ERROR(call_graph->nodes_.back().AddCallSiteForInstruction( - instruction.get())); + call_graph->nodes_.back().AddCallSiteForInstruction(instruction.get()); } } // Add caller callsites to each node. for (const std::unique_ptr& computation : module->computations()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * caller_node, - call_graph->GetNode(computation.get())); - for (const CallSite& callsite : caller_node->callsites()) { + for (const CallSite& callsite : + call_graph->GetNode(computation.get()).callsites()) { for (auto* callee : callsite.called_computations()) { // Add caller callsites. - TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, - call_graph->GetNode(callee)); - callee_node->AddCallerCallSite(callsite); + call_graph->GetNode(callee).AddCallerCallSite(callsite); } } } - TF_RETURN_IF_ERROR(call_graph->SetCallContexts()); - + call_graph->SetCallContexts(); XLA_VLOG_LINES(1, call_graph->ToString()); - return std::move(call_graph); + return call_graph; } Status CallGraph::VisitNodesInternal( - const VisitorFunction& visitor_func, const CallGraphNode* node, + const VisitorFunction& visitor_func, const CallGraphNode& node, tensorflow::gtl::FlatSet* visited) const { - auto pair = visited->insert(node); + auto pair = visited->insert(&node); if (!pair.second) { // Node was not inserted. Node has already been visited. 
return Status::OK(); } - for (const HloComputation* computation : node->callees()) { - TF_ASSIGN_OR_RETURN(const CallGraphNode* callee_node, GetNode(computation)); - TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, callee_node, visited)); + for (const HloComputation* computation : node.callees()) { + TF_RETURN_IF_ERROR( + VisitNodesInternal(visitor_func, GetNode(computation), visited)); } - return visitor_func(*node); + return visitor_func(node); } Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, @@ -278,14 +270,13 @@ Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { if (node.callers().empty()) { - TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, &node, &visited)); + TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited)); } } } else { // Traverse only from the entry computation. - TF_ASSIGN_OR_RETURN(const CallGraphNode* entry_node, - GetNode(module_->entry_computation())); - TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, entry_node, &visited)); + TF_RETURN_IF_ERROR(VisitNodesInternal( + visitor_func, GetNode(module_->entry_computation()), &visited)); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 62d12f8f91b099b452143c98427fdd1e6867ac7d..7f9990f06d4fee4c52fa516fc2f6031f5dab2bb9 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -138,7 +137,7 @@ class CallGraphNode { // If instruction calls any computations adds a call site for this instruction // to the call graph node. If the instruction calls no computations then no // call site is added. - Status AddCallSiteForInstruction(HloInstruction* instruction); + void AddCallSiteForInstruction(HloInstruction* instruction); // Computation represented by this call graph node. HloComputation* computation_; @@ -174,12 +173,11 @@ class CallGraph { using VisitorFunction = std::function; // Builds and returns a call graph for the given HLO module. - static StatusOr> Build(const HloModule* module); + static std::unique_ptr Build(const HloModule* module); // Returns the node associated with the given computation. - StatusOr GetNode( - const HloComputation* computation) const; - StatusOr GetNode(const HloComputation* computation); + const CallGraphNode& GetNode(const HloComputation* computation) const; + CallGraphNode& GetNode(const HloComputation* computation); // Returns the vector of all nodes in the call graph. const std::vector& nodes() const { return nodes_; } @@ -197,14 +195,14 @@ class CallGraph { CallGraph(const HloModule* module); // Sets the call contexts for every node in the graph. - Status SetCallContexts(); + void SetCallContexts(); // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS // post order (callee before caller) calling visitor_func on each node. Adds // nodes to 'visited' as each node is visited. Skips nodes already in // 'visited'. 
Status VisitNodesInternal( - const VisitorFunction& visitor_func, const CallGraphNode* node, + const VisitorFunction& visitor_func, const CallGraphNode& node, tensorflow::gtl::FlatSet* visited) const; // The HLO module represented by this call graph. diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index f71a5d01afa20e6c4e86ad8ef7a3a68c5e23e210..ab0ea47d024d871be88bfcab957810deb1ecac99 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -95,17 +95,15 @@ TEST_F(CallGraphTest, SingletonComputation) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(1, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* node, - call_graph->GetNode(computation)); - EXPECT_EQ(computation, node->computation()); - EXPECT_TRUE(node->callsites().empty()); - EXPECT_TRUE(node->callees().empty()); - EXPECT_TRUE(node->caller_callsites().empty()); - EXPECT_TRUE(node->callers().empty()); - EXPECT_EQ(CallContext::kSequential, node->context()); + const CallGraphNode& node = call_graph->GetNode(computation); + EXPECT_EQ(computation, node.computation()); + EXPECT_TRUE(node.callsites().empty()); + EXPECT_TRUE(node.callees().empty()); + EXPECT_TRUE(node.caller_callsites().empty()); + EXPECT_TRUE(node.callers().empty()); + EXPECT_EQ(CallContext::kSequential, node.context()); } TEST_F(CallGraphTest, UnreachableComputation) { @@ -117,19 +115,17 @@ TEST_F(CallGraphTest, UnreachableComputation) { HloComputation* unreachable_computation = module.AddEmbeddedComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); 
EXPECT_EQ(2, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - EXPECT_EQ(entry_computation, entry_node->computation()); - EXPECT_EQ(CallContext::kSequential, entry_node->context()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(CallContext::kSequential, entry_node.context()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* unreachable_node, - call_graph->GetNode(unreachable_computation)); - EXPECT_EQ(unreachable_computation, unreachable_node->computation()); - EXPECT_EQ(CallContext::kSequential, unreachable_node->context()); + const CallGraphNode& unreachable_node = + call_graph->GetNode(unreachable_computation); + EXPECT_EQ(unreachable_computation, unreachable_node.computation()); + EXPECT_EQ(CallContext::kSequential, unreachable_node.context()); } TEST_F(CallGraphTest, ParallelComputation) { @@ -141,27 +137,24 @@ TEST_F(CallGraphTest, ParallelComputation) { HloComputation* entry_computation = module.AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(2, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - EXPECT_EQ(entry_computation, entry_node->computation()); - EXPECT_EQ(CallContext::kSequential, entry_node->context()); - EXPECT_EQ(5, entry_node->callsites().size()); - EXPECT_EQ(1, entry_node->callees().size()); - EXPECT_TRUE(entry_node->caller_callsites().empty()); - EXPECT_TRUE(entry_node->callers().empty()); - - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* map_node, - call_graph->GetNode(map_computation)); - EXPECT_EQ(map_computation, map_node->computation()); - EXPECT_EQ(CallContext::kParallel, map_node->context()); - 
EXPECT_TRUE(map_node->callsites().empty()); - EXPECT_TRUE(map_node->callees().empty()); - EXPECT_EQ(5, map_node->caller_callsites().size()); - EXPECT_EQ(1, map_node->callers().size()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(CallContext::kSequential, entry_node.context()); + EXPECT_EQ(5, entry_node.callsites().size()); + EXPECT_EQ(1, entry_node.callees().size()); + EXPECT_TRUE(entry_node.caller_callsites().empty()); + EXPECT_TRUE(entry_node.callers().empty()); + + const CallGraphNode& map_node = call_graph->GetNode(map_computation); + EXPECT_EQ(map_computation, map_node.computation()); + EXPECT_EQ(CallContext::kParallel, map_node.context()); + EXPECT_TRUE(map_node.callsites().empty()); + EXPECT_TRUE(map_node.callees().empty()); + EXPECT_EQ(5, map_node.caller_callsites().size()); + EXPECT_EQ(1, map_node.callers().size()); } TEST_F(CallGraphTest, SequentialComputations) { @@ -173,27 +166,24 @@ TEST_F(CallGraphTest, SequentialComputations) { HloComputation* entry_computation = module.AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(2, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - EXPECT_EQ(entry_computation, entry_node->computation()); - EXPECT_EQ(CallContext::kSequential, entry_node->context()); - EXPECT_EQ(3, entry_node->callsites().size()); - EXPECT_EQ(1, entry_node->callees().size()); - EXPECT_TRUE(entry_node->caller_callsites().empty()); - EXPECT_TRUE(entry_node->callers().empty()); - - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* called_node, - call_graph->GetNode(called_computation)); - EXPECT_EQ(called_computation, called_node->computation()); - EXPECT_EQ(CallContext::kSequential, 
called_node->context()); - EXPECT_TRUE(called_node->callsites().empty()); - EXPECT_TRUE(called_node->callees().empty()); - EXPECT_EQ(3, called_node->caller_callsites().size()); - EXPECT_EQ(1, called_node->callers().size()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(CallContext::kSequential, entry_node.context()); + EXPECT_EQ(3, entry_node.callsites().size()); + EXPECT_EQ(1, entry_node.callees().size()); + EXPECT_TRUE(entry_node.caller_callsites().empty()); + EXPECT_TRUE(entry_node.callers().empty()); + + const CallGraphNode& called_node = call_graph->GetNode(called_computation); + EXPECT_EQ(called_computation, called_node.computation()); + EXPECT_EQ(CallContext::kSequential, called_node.context()); + EXPECT_TRUE(called_node.callsites().empty()); + EXPECT_TRUE(called_node.callees().empty()); + EXPECT_EQ(3, called_node.caller_callsites().size()); + EXPECT_EQ(1, called_node.callers().size()); } TEST_F(CallGraphTest, ContextBothComputations) { @@ -213,32 +203,29 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module.AddEntryComputation(builder.Build()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(2, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - EXPECT_EQ(entry_computation, entry_node->computation()); - EXPECT_EQ(2, entry_node->callsites().size()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(2, entry_node.callsites().size()); - const CallSite& call_callsite = entry_node->callsites()[0]; + const CallSite& call_callsite = entry_node.callsites()[0]; EXPECT_EQ(call, call_callsite.instruction()); EXPECT_THAT(call_callsite.called_computations(), 
UnorderedElementsAre(subcomputation)); EXPECT_EQ(CallContext::kSequential, call_callsite.context()); - EXPECT_EQ(entry_node->GetCallSite(call), &call_callsite); + EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite); - const CallSite& map_callsite = entry_node->callsites()[1]; + const CallSite& map_callsite = entry_node.callsites()[1]; EXPECT_EQ(map, map_callsite.instruction()); EXPECT_THAT(map_callsite.called_computations(), UnorderedElementsAre(subcomputation)); EXPECT_EQ(CallContext::kParallel, map_callsite.context()); - EXPECT_EQ(entry_node->GetCallSite(map), &map_callsite); + EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* sub_node, - call_graph->GetNode(subcomputation)); - EXPECT_EQ(CallContext::kBoth, sub_node->context()); + const CallGraphNode& sub_node = call_graph->GetNode(subcomputation); + EXPECT_EQ(CallContext::kBoth, sub_node.context()); } TEST_F(CallGraphTest, ComplexGraph) { @@ -284,27 +271,24 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module.AddEntryComputation(builder.Build()); } - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(5, call_graph->nodes().size()); // Entry computation has one while instruction calling two computations // (cond_computation and a_computation). 
- TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - ASSERT_EQ(1, entry_node->callsites().size()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + ASSERT_EQ(1, entry_node.callsites().size()); const std::vector& called_computations = - entry_node->callsites()[0].called_computations(); + entry_node.callsites()[0].called_computations(); EXPECT_THAT(called_computations, UnorderedElementsAre(cond_computation, a_computation)); - EXPECT_EQ(CallContext::kSequential, entry_node->context()); + EXPECT_EQ(CallContext::kSequential, entry_node.context()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node, - call_graph->GetNode(c_computation)); - EXPECT_TRUE(c_node->callsites().empty()); - EXPECT_THAT(c_node->callers(), + const CallGraphNode& c_node = call_graph->GetNode(c_computation); + EXPECT_TRUE(c_node.callsites().empty()); + EXPECT_THAT(c_node.callers(), UnorderedElementsAre(a_computation, b_computation)); - EXPECT_EQ(CallContext::kBoth, c_node->context()); + EXPECT_EQ(CallContext::kBoth, c_node.context()); // Visit the graph and verify nodes were visited in callee-before-caller // order. 
@@ -337,8 +321,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -355,8 +338,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { module.AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module.AddEmbeddedComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); // Test visitation of only reachable nodes. { @@ -390,8 +372,7 @@ TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. HloModule module(TestName()); module.AddEntryComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac1906c88c47a1efff305f2a45de66b84048af37 --- /dev/null +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -0,0 +1,131 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compile_only_service.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +/* static */ StatusOr> +CompileOnlyService::NewService(perftools::gputools::Platform* platform) { + ServiceOptions default_options; + default_options.set_platform(platform); + return NewService(default_options); +} + +/* static */ StatusOr> +CompileOnlyService::NewService(const ServiceOptions& options) { + perftools::gputools::Platform* platform = options.platform(); + if (platform == nullptr) { + TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + } + + TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, + CreateComputeConstantBackend()); + std::unique_ptr service( + new CompileOnlyService(compiler, std::move(compute_constant_backend))); + return 
std::move(service); +} + +CompileOnlyService::CompileOnlyService( + Compiler* compiler, std::unique_ptr compute_constant_backend) + : Service(/*backend=*/nullptr, std::move(compute_constant_backend)), + compiler_(compiler) { + runs_in_client_process_ = true; +} + +StatusOr>> +CompileOnlyService::CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options) { + std::vector> hlo_modules; + std::vector> module_configs; + for (const AotComputationInstance& instance : computations) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(instance.computation)); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + // Dump computation proto state if flag is set. + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + const string& directory_path = flags->xla_dump_computations_to; + if (!directory_path.empty()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr session_module, + computation_tracker_.SnapshotComputation(versioned_handle.handle)); + string filename = tensorflow::strings::StrCat( + "computation_", versioned_handle.handle.handle(), "__", + session_module->entry().name(), "__version_", + versioned_handle.version); + TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, + *session_module)); + } + + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + computation_tracker_.BuildHloModule( + versioned_handle, + /*include_unreachable_instructions=*/true)); + hlo_modules.push_back(std::move(hlo_module)); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + module_configs.push_back(MakeUnique(*program_shape)); + HloModuleConfig* module_config = module_configs.back().get(); + auto* computation_layout = + module_config->mutable_entry_computation_layout(); + if (flags->xla_hlo_profile) { + module_config->enable_hlo_profiling(true); + } + for (int i = 
0; i < instance.argument_layouts.size(); ++i) { + const Shape& argument_layout = *instance.argument_layouts[i]; + if (ShapeUtil::IsTuple(argument_layout)) { + return Unimplemented("tuple arguments not supported yet"); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + argument_layout)); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + *instance.result_layout)); + } + + return compiler_->CompileAheadOfTime(std::move(hlo_modules), + std::move(module_configs), + MakeHloDumper(), options); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h new file mode 100644 index 0000000000000000000000000000000000000000..6dae49e3e1acf144847d44af4507880d8bf2efc4 --- /dev/null +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -0,0 +1,125 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_ + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// An XLA Service specialization for ahead-of-time compilation. This only +// instantiates a Compiler object for the relevant platform; it does not +// instantiate or require an execution backend. +class CompileOnlyService : public Service { + public: + // Factory for creating a CompileOnlyService. The parameter platform is the + // platform that the service should target. If platform is null then the + // default platform is used. + static StatusOr> NewService( + perftools::gputools::Platform* platform); + static StatusOr> NewService( + const ServiceOptions& options); + + // A description of a computation to compile using CompileAheadOfTime. + struct AotComputationInstance { + ComputationHandle computation; + std::vector argument_layouts; + const Shape* result_layout = nullptr; + }; + + // Compiles a list of computations for ahead-of-time execution. This is + // intended for use in static compilation. See + // |CompileOnlyClient::CompileAheadOfTime| for additional details. + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& Options); + + // Override Service methods that require or imply the existence of an + // execute backend. Note that this does not include TransferToClient and + // TransferToClientInProcess, as computing contants produces global data + // that we may wish to transfer. 
+ tensorflow::Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) override { + return Unimplemented("CompileOnlyService does not support execution."); + } + tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override { + return Unimplemented("CompileOnlyService does not support execution."); + } + tensorflow::Status GetDeviceHandles( + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override { + return Unimplemented("CompileOnlyService does not support devices."); + } + tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override { + return Unimplemented("CompileOnlyService does not support execution."); + } + tensorflow::Status WaitForExecution( + const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override { + return Unimplemented("CompileOnlyService does not support execution."); + } + tensorflow::Status TransferToServer( + const TransferToServerRequest* arg, + TransferToServerResponse* result) override { + return Unimplemented( + "CompileOnlyService does not support device data transfers."); + } + tensorflow::Status TransferToInfeed( + const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override { + return Unimplemented( + "CompileOnlyService does not support device data transfers."); + } + tensorflow::Status TransferFromOutfeed( + const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override { + return Unimplemented( + "CompileOnlyService does not support device data transfers."); + } + tensorflow::Status TransferToServerInProcess( + const TransferToServerInProcessRequest* arg, + TransferToServerInProcessResponse* result) override { + return Unimplemented( + "CompileOnlyService does not support device data transfers."); + } + tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override { + return 
Unimplemented("CompileOnlyService does not support devices."); + } + + private: + explicit CompileOnlyService( + Compiler* compiler, std::unique_ptr compute_constant_backend); + CompileOnlyService(const CompileOnlyService&) = delete; + void operator=(const CompileOnlyService&) = delete; + + // The compiler for the target platform. This is included in place of + // the Service::execute_backend_'s compiler, since execute_backend_ is a + // nullptr in CompileOnlyService. + Compiler* compiler_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 7db28aed3cd2045d6c1e94a390ce632bf3bbe9de..907b0307d4b61018814b02737fba4837c2e1d668 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -319,6 +320,7 @@ Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) { VLOG(2) << "Adding copy of buffer for instruction: " << instruction_->name() + << " instruction_buffer: " << instruction_buffer->ToString() << " at index: " << tensorflow::str_util::Join(index, ",") << " because of interference with buffer: " << other_buffer->ToString(); @@ -351,6 +353,11 @@ Status InstructionCopier::RecordControlPredecessors( for (const BufferAlias& alias : points_to_analysis.GetBufferAliases(*buffer)) { for (HloInstruction* user : 
alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + user, points_to_analysis)) { + continue; + } + if (user != instruction_) { control_predecessors_.mutable_element(index)->push_back(user); } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6e24506b383c61a3e346d3e3250511cd6a2d4940..affb5f99066d8278c583c469d97e78646d52f3c6 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -53,7 +53,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/port:initialize", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -98,6 +97,7 @@ cc_library( name = "simple_orc_jit", srcs = ["simple_orc_jit.cc"], hdrs = ["simple_orc_jit.h"], + linkopts = ["-ldl"], deps = [ ":compiler_functor", ":cpu_runtime", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index f717d57839f4cfc59121b9e8e39b5b9c63c9b60d..b42702dbe1abe3db838159bda2665743e416a2d5 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" @@ -28,6 +29,8 @@ limitations under the License. 
namespace xla { namespace cpu { +using ::testing::ElementsAre; + class ConvCanonicalizationTest : public HloTestBase { public: ConvCanonicalizationTest() { @@ -96,14 +99,14 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { // The input is in CNHW order. input_reshape should produce // NHWC for the convolution to hit the Eigen fast path. - EXPECT_TRUE(ContainersEqual(input_reshape->dimensions(), {1, 2, 3, 0})); + EXPECT_THAT(input_reshape->dimensions(), ElementsAre(1, 2, 3, 0)); // The kernel is in OIHW order. kernel_reshape should produce // HWIO for the convolution to hit the Eigen fast path. - EXPECT_TRUE(ContainersEqual(kernel_reshape->dimensions(), {2, 3, 1, 0})); + EXPECT_THAT(kernel_reshape->dimensions(), ElementsAre(2, 3, 1, 0)); // The output of the canonical convolution is in NHWC order (the same as // input_reshape's order). output_reshape should restore that order to the // order of the computation root (CNHW). - EXPECT_TRUE(ContainersEqual(output_reshape->dimensions(), {3, 0, 1, 2})); + EXPECT_THAT(output_reshape->dimensions(), ElementsAre(3, 0, 1, 2)); } TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 3e6be5a7a2374d4d274f95b7b8e2d814f8ace8b1..1ba45e59838c10ab5c050cb74e263eca70783fb0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -39,7 +39,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/port/initialize.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -683,8 +682,10 @@ int64 CpuCompiler::ShapeSizeBytes(const Shape& shape) const { } // namespace cpu } // namespace xla -REGISTER_MODULE_INITIALIZER(cpu_compiler, { +static bool InitModule() { xla::Compiler::RegisterCompilerFactory(se::host::kHostPlatformId, []() { return xla::MakeUnique(); }); -}); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 240da35ef190eb7080947ab7d1da91d8d2dd8973..dc002846e9e6b07c767ddc8af939657c4c51bf23 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -24,6 +24,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Output fusion is not currently supported on CPUs. + if (producer->opcode() == HloOpcode::kFusion) { + return false; + } + // Condition for consumer: must be elementwise or a fusion op // (which necessarily only contains elementwise operations) if (!(consumer->opcode() == HloOpcode::kFusion || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 8e06f0520edfb05c7ec606dcb8e85c5ef997c2c0..253de20f25127bf0ac23d5969e0f16c143396e47 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include #include #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 1c704fd1ee77f3effad2b460e955efe53e441310..1e34de9e4bde992154ece2b8ff0783c9fb2b8a1a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -201,7 +201,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name, if (&argument == retval) { continue; } - compute_function_->setDoesNotAlias(argument.getArgNo() + 1); + compute_function_->addAttribute(argument.getArgNo() + 1, + llvm::Attribute::NoAlias); } ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index 7a4723e8d75588d8ccb711892b4082024695e444..cadad10910132c716eefd4ecfba53f3d7e02df99 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -146,7 +146,7 @@ Status ParallelCpuExecutable::AllocateBuffers( } Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { @@ -160,7 +160,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { @@ -214,7 +214,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( void** temps_array = buffer_pointers.data(); uint64* profile_counters_array = 
profile_counters.data(); - auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool()); + auto* thread_pool = CHECK_NOTNULL(run_options->xla_intra_op_thread_pool()); tensorflow::mutex completion_queue_lock; tensorflow::condition_variable completion_queue_cv; std::deque completion_queue; @@ -251,11 +251,12 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( }); auto function = FindOrDie(functions, instruction); // The thread pool entry takes ownership of |operand_buffers|. + const auto* exec_run_options = &run_options->run_options(); thread_pool->Schedule([instruction, &completion_queue, &completion_queue_lock, &completion_queue_cv, - result_buffer, run_options, operand_buffers, + result_buffer, exec_run_options, operand_buffers, temps_array, profile_counters_array, function] { - function(result_buffer, run_options, operand_buffers, temps_array, + function(result_buffer, exec_run_options, operand_buffers, temps_array, profile_counters_array); delete[] operand_buffers; // Push the completed HLO instruction on the queue, the main thread @@ -345,9 +346,8 @@ ParallelCpuExecutable::ExecuteOnStream( const BufferAllocation::Index result_index = result_slice.index(); VLOG(3) << "result index: " << result_index; - TF_RETURN_IF_ERROR(ExecuteComputeFunctions(&run_options->run_options(), - arguments, device_allocations, - hlo_execution_profile)); + TF_RETURN_IF_ERROR(ExecuteComputeFunctions( + run_options, arguments, device_allocations, hlo_execution_profile)); // Mark the buffers that are actually live (used in the output) when the // computation finishes executing. 
@@ -400,8 +400,8 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunctions( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); + TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers, + hlo_execution_profile)); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer which is returned to the caller. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index 7223de9f0798365138cdb26ca9dce07cd0e474e3..6e1239d590c1f5698066cd77b5637912e14264e7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -96,14 +96,14 @@ class ParallelCpuExecutable : public Executable { // Calls the generated functions in 'function_names_', performing the // computation with the given arguments using the supplied buffers. 
Status ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 677080a8623224cdd65e35b3116ae57b7b3b3ca2..332f4216dc7b970cb985719ef82d5aa82bb86d3d 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -53,8 +53,8 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; - const Eigen::array dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + const Eigen::array dims({ + DimPair(lhs_contract_dim, rhs_contract_dim) }); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 384a978873de89526f43556296aaa51c46ac1d3f..e45329c4ef52090c4d8b50c1afc452d0dadceb35 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -47,8 +47,8 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, typedef typename Eigen::Tensor::DimensionPair DimPair; int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 
1 : 0; - const Eigen::array dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + const Eigen::array dims({ + DimPair(lhs_contract_dim, rhs_contract_dim)}); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index a04815dad94484a6f01ebd27d3ec73f547086722..bea1da4044669f5e910af09ba1b65416a69367b5 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -240,14 +240,18 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( return ir_builder_->CreateFDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: return ir_builder_->CreateFRem(lhs_value, rhs_value); - - // The 'O' prefix on the LLVM ops means "ordered" compare where comparisons - // with NAN always return false. + // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered + // comparisons always return false when one of the operands is NaN, whereas + // unordered comparisons return true. + // + // We use ordered comparisons for everything except kNe, where we use an + // unordered comparison. This makes x != y equivalent to !(x == y), and + // matches C++'s semantics. 
case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, rhs_value, ir_builder_); case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_ONE, lhs_value, + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, rhs_value, ir_builder_); case HloOpcode::kLt: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, @@ -739,11 +743,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(operand_idx); auto true_block = llvm_ir::CreateBasicBlock( exit_block, tensorflow::strings::StrCat( - "concat_index_from_operand", operand_idx), + "concat_index_from_operand", operand_idx), ir_builder_); auto false_block = llvm_ir::CreateBasicBlock( exit_block, tensorflow::strings::StrCat( - "concat_index_not_from_operand", operand_idx), + "concat_index_not_from_operand", operand_idx), ir_builder_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index 3c41fe870f2109a16f4d47aee5195a5537380bcb..297a4f7599f9c127386b2f53f7ffb987befc456e 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -102,8 +102,7 @@ Status FlattenNode(const CallGraphNode& node) { StatusOr FlattenCallGraph::Run(HloModule* module) { XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString()); - TF_ASSIGN_OR_RETURN(std::unique_ptr call_graph, - CallGraph::Build(module)); + std::unique_ptr call_graph = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph->VisitNodes(FlattenNode)); XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 
6c4a48bbe8e096c00b4ebc2e991a8ff38c06a07b..4e03a96fb3f03710cd3062a79aa4955311cf19c1 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -141,11 +141,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { { TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module)); EXPECT_TRUE(result); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr flat_call_graph, - CallGraph::Build(&module)); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node, - flat_call_graph->GetNode(c_computation)); - EXPECT_EQ(1, c_node->caller_callsites().size()); + std::unique_ptr flat_call_graph = CallGraph::Build(&module); + const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); + EXPECT_EQ(1, c_node.caller_callsites().size()); } } @@ -178,21 +176,17 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* cond_node, - call_graph->GetNode(cond_computation)); - EXPECT_EQ(2, cond_node->caller_callsites().size()); + std::unique_ptr call_graph = CallGraph::Build(&module); + const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); + EXPECT_EQ(2, cond_node.caller_callsites().size()); } { TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module)); EXPECT_TRUE(result); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* cond_node, - call_graph->GetNode(cond_computation)); - EXPECT_EQ(1, cond_node->caller_callsites().size()); + std::unique_ptr call_graph = CallGraph::Build(&module); + const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); + EXPECT_EQ(1, cond_node.caller_callsites().size()); } } @@ -219,17 +213,14 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module)); EXPECT_TRUE(result); - 
TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(7, module.computations().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node, - call_graph->GetNode(c_computation)); - EXPECT_EQ(1, c_node->caller_callsites().size()); + const CallGraphNode& c_node = call_graph->GetNode(c_computation); + EXPECT_EQ(1, c_node.caller_callsites().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* b_node, - call_graph->GetNode(b_computation)); - EXPECT_EQ(1, b_node->caller_callsites().size()); + const CallGraphNode& b_node = call_graph->GetNode(b_computation); + EXPECT_EQ(1, b_node.caller_callsites().size()); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1fdbcfe5641ab3c0cda63268082069df765bf4e6..d26f415fd4bdfec597c70b760942cc406a0d6cfa 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -264,6 +264,8 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform/default/build_config:cublas_plugin", + "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1667ab36792c91cbbf3c6396a673bedff2208045..e57eb0bdee64948290d5eaf15965afcdc8bea0ad 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -113,7 +113,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, PrimitiveType output_type) const { - // Binary math functions tranform are of type [T] -> T. 
+ // Binary math functions are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { return Unimplemented("Input type ≠ output type: %s ≠ %s", diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 34a44ad40548272a0c2a87efadfa1ab2aca7b979..a36dcbbd2faf3258ec2790f51bb2aec3ce834a6c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -46,6 +46,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Output fusion is not currently supported on GPUs. + if (producer->opcode() == HloOpcode::kFusion) { + return false; + } + // RNG operations are not currently parallel-friendly on GPU. if (producer->opcode() == HloOpcode::kRng) { return false; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index e8378a7f447cebf8d491e98595188d2391333c58..c6e8a2f78b5a398d9e9d5a684ac4d42520ec20c8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,11 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { + // We can only do this if the HLO is unnested. + if (hlo.parent() != hlo.GetModule()->entry_computation()) { + return false; + } + // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -85,6 +90,11 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { } bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { + // We can only do this if the HLO is unnested. 
+ if (hlo.parent() != hlo.GetModule()->entry_computation()) { + return false; + } + // Forward convolution. if (hlo.opcode() == HloOpcode::kConvolution) { const ConvolutionDimensionNumbers& dnums = diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index e60978df0a2a9c7911c71314e5325ee0fbfd67e0..36619a845413b19ec2d559252409dae1b96b76e4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -399,7 +399,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot, llvm::Type* accum_type = target_array.GetElementLlvmType(); llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry( accum_type, // The pointee type of the alloca instruction. - "accum_address", // The name of the alloca instuction. + "accum_address", // The name of the alloca instruction. &ir_builder_); // Initialize the accumulator in the preheader to zero. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 04babcca0c822700e4e47c66433e8d3ea6ac3d39..e52e55a1a8199019e2c149a777a4e948f830ce0e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -196,7 +196,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ir_emitter_context_->buffer_assignment().GetTempAllocation()) { kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size()); } - kernel->setDoesNotAlias(temp_buffer_arg_no + 1); + kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias); // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX // treats it as a CUDA kernel. 
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 485216837dc727bfe8565ff22678dd2fa470bc40..383729185df14404c4479993a7cdec771a63b26e 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -396,7 +396,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // The LLVM IR verifier performs sanity checking on the IR. This helps // discover problems and report them in a meaningful manner, rather than let - // later passes report obscure assertions becasue of unfulfilled invariants. + // later passes report obscure assertions because of unfulfilled invariants. module_passes.add(llvm::createVerifierPass()); // Create the function-level pass manager. It needs data layout information diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index 46a5d303b74af6b312b9e7d774dd484336322b4e..61bc6f6055740a3632ddd1cad94491de97309ae6 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -168,7 +168,7 @@ class MatcherBase { virtual ~MatcherBase() {} // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first succesful match, error status otherwise. + // Returns OK on the first successful match, error status otherwise. 
virtual tensorflow::Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 9c4899a67debfebf72b93b412a07ad60993fd819..d7aa5664df40f24d17b48e846839c22cf7922f75 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -53,12 +53,44 @@ std::vector UniqueOperandSourceBuffers( /*static*/ StatusOr HeapSimulator::Run( - std::unique_ptr algorithm, + std::unique_ptr algorithm, const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_fn, + const FlatSet* buffers_to_assign) { + HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); + const HloComputation* entry_computation = module.entry_computation(); + const std::vector& instruction_sequence = + FindOrDie(module_sequence, entry_computation); + TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation, + instruction_sequence, + points_to_analysis, &module_sequence)); + return heap.Finish(); +} + +/*static*/ +StatusOr HeapSimulator::Run( + std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, - const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, const FlatSet* buffers_to_assign) { + HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); + TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, + points_to_analysis, + /*module_sequence=*/nullptr)); + return heap.Finish(); +} + +// Runs a heap simulation for the given 'computation', assuming the given +// 'instruction_sequence'. 
If 'module_sequence' is non-null, it is used to find +// kCall and kWhile sub-computations, and the heap simulation for those +// sub-computations will be run recursively. +Status HeapSimulator::RunComputation( + const HloComputation& computation, + const std::vector& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const SequentialHloOrdering::HloModuleSequence* module_sequence) { // The goal here is to minimize memory usage, assuming the given sequential // ordering of instructions. The strategy is to walk through the instruction // sequence, calling Alloc and Free on the underlying heap algorithm. The @@ -67,7 +99,6 @@ StatusOr HeapSimulator::Run( // 'live_buffers' tracks the liveness of each buffer that we assign, by // associating it with a set of HloInstructions that need to be visited. When // the set becomes empty, the buffer is no longer used, and can be freed. - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); FlatMap> live_buffers; const HloInstruction* root = computation.root_instruction(); @@ -90,7 +121,7 @@ StatusOr HeapSimulator::Run( // lifetime of buffers that aren't already connected by a data dependency. std::vector dead_buffers_to_free; for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { - if (heap.IgnoreBuffer(buffer)) { + if (IgnoreBuffer(buffer)) { continue; } for (const BufferAlias& alias : @@ -127,7 +158,7 @@ StatusOr HeapSimulator::Run( std::vector operand_buffers_to_free; for (const LogicalBuffer* operand_buffer : UniqueOperandSourceBuffers(instruction, points_to_analysis)) { - if (heap.IgnoreBuffer(operand_buffer)) { + if (IgnoreBuffer(operand_buffer)) { continue; } live_buffers[operand_buffer].erase(instruction); @@ -142,10 +173,10 @@ StatusOr HeapSimulator::Run( // happen before dead or operand buffers are freed; the instruction reads // the operand buffers to produce its output. 
// - // INVARIANT: Either heap.Alloc or heap.ShareBuffer will be called for each - // buffer that we should assign. + // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer + // that we should assign. for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { - if (heap.IgnoreBuffer(buffer)) { + if (IgnoreBuffer(buffer)) { continue; } @@ -159,24 +190,50 @@ StatusOr HeapSimulator::Run( CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), buffer->instruction(), buffer->index(), points_to_analysis)) { - heap.ShareBuffer(buffer, operand_buffer); + ShareBuffer(buffer, operand_buffer); shared = true; break; } } if (!shared) { - heap.Alloc(buffer); + Alloc(buffer); } } + // If the whole module is sequential, we can save memory by running the + // heap-simulation for sub-computations inline. E.g. the buffers for the + // condition and body of a kWhile instruction are only live for the duration + // of the instruction itself. + // + // The order that the sub-computations are simulated does not affect + // correctness; since the whole module is sequential, we know that the + // sub-computations will never be run concurrently. + if (module_sequence != nullptr) { + if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kWhile) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + const std::vector& called_sequence = + FindOrDie(*module_sequence, called_computation); + TF_RETURN_IF_ERROR(RunComputation(*called_computation, + called_sequence, points_to_analysis, + module_sequence)); + } + } + + // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are + // assigned "thread-local" allocations, meaning their buffers are not + // allocated up-front at the beginning of the computation. + } + // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. 
for (const LogicalBuffer* buffer : dead_buffers_to_free) { - heap.Free(buffer); + Free(buffer); } for (const LogicalBuffer* buffer : operand_buffers_to_free) { - heap.Free(buffer); + Free(buffer); } } @@ -187,10 +244,10 @@ StatusOr HeapSimulator::Run( const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; - heap.Free(buffer); + Free(buffer); } - return heap.Finish(); + return Status::OK(); } HeapSimulator::HeapSimulator( @@ -309,6 +366,11 @@ HeapSimulator::Result HeapSimulator::Finish() { result.chunk_map.emplace(buffer, chunk); } } + // If we were told to assign specific buffers, make sure we've assigned + // exactly that many buffers. + if (buffers_to_assign_ != nullptr) { + CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size()); + } } // Fragmentation is the difference between the actual and ideal sizes. diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 0ce2906767898bcace45e296d76f958c50a2b3a7..3d98046261902b41a17a8ab0f9a349634a1e4545 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,17 +64,32 @@ class HeapSimulator { }; // Run the heap simulation with the given algorithm, assuming the given - // sequential ordering of instructions. The 'instruction_sequence' must - // contain a topologically-consistent total ordering of all instructions in - // the computation. 
The result is invalid if instructions are not run in - // exactly this sequence. + // module_sequence, which must contain a topologically-consistent total + // ordering of all instructions within each computation. The result is invalid + // if instructions are not run in exactly this sequence. + // + // Running heap simulation on the whole module tends to save memory, compared + // to running on a per-computation basis, since we can re-use buffer space for + // called sub-computations. // // If 'buffers_to_assign' is provided, only those buffers are assigned // offsets, otherwise all buffers defined by the instructions are assigned. + static StatusOr Run( + std::unique_ptr algorithm, const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_fn, + const tensorflow::gtl::FlatSet* buffers_to_assign = + nullptr); + + // Same as above, but runs on a single computation. The 'instruction_sequence' + // must contain a topologically-consistent total ordering of all instructions + // in the computation. The result is invalid if instructions are not run in + // exactly this sequence. 
static StatusOr Run( std::unique_ptr algorithm, - const std::vector& instruction_sequence, const HloComputation& computation, + const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, const tensorflow::gtl::FlatSet* buffers_to_assign = @@ -86,6 +102,12 @@ class HeapSimulator { const tensorflow::gtl::FlatSet* buffers_to_assign); ~HeapSimulator(); + Status RunComputation( + const HloComputation& computation, + const std::vector& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const SequentialHloOrdering::HloModuleSequence* module_sequence); + bool IgnoreBuffer(const LogicalBuffer* buffer) const; void Alloc(const LogicalBuffer* buffer); void Free(const LogicalBuffer* buffer); diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 874bd5f1060c179d5547510c351909069aa935b8..0a6900f73304f7a7b1209807fd3a1e8220484e03 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,13 +19,16 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -69,6 +72,7 @@ class HeapCallRecorder : public HeapAlgorithm { // sequence against an expected sequence. 
class HeapSimulatorTracker { public: + // Constructor for testing a single entry computation. HeapSimulatorTracker( const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { @@ -83,12 +87,48 @@ class HeapSimulatorTracker { auto zero_size = [](const LogicalBuffer& buffer) { return 0; }; auto algorithm = MakeUnique( MakeUnique(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), instruction_sequence, - *module_->entry_computation(), - *points_to_analysis_, zero_size) + result_ = HeapSimulator::Run( + std::move(algorithm), *module_->entry_computation(), + instruction_sequence, *points_to_analysis_, zero_size) .ConsumeValueOrDie(); } + explicit HeapSimulatorTracker(const string& name) { + module_ = MakeUnique(name); + } + + // Similar to the single entry computation constructor above, but runs the + // simulation over the entire module. + void RunWholeModule( + const std::vector& full_module_sequence) { + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + + // Construct the module sequence grouped by computation. + SequentialHloOrdering::HloModuleSequence module_sequence; + tensorflow::gtl::FlatMap reverse_position; + for (int i = 0; i < full_module_sequence.size(); ++i) { + const HloInstruction* instruction = full_module_sequence[i]; + module_sequence[instruction->parent()].push_back(instruction); + reverse_position[instruction] = full_module_sequence.size() - i; + } + + // Hack the size_fn so that it returns a decreasing value as we step through + // the sequence. This lets us ensure the Alloc calls are in the sequence + // order. The Free calls are sorted by LogicalBuffer.id, which is at least + // deterministic. 
+ auto size_fn = [&reverse_position](const LogicalBuffer& buffer) { + return reverse_position[buffer.instruction()]; + }; + auto algorithm = MakeUnique( + MakeUnique(&actual_calls_)); + result_ = HeapSimulator::Run(std::move(algorithm), *module_, + module_sequence, *points_to_analysis_, size_fn) + .ConsumeValueOrDie(); + } + + HloModule* module() { return module_.get(); } + // Returns the buffer defined at the given instruction and index. const LogicalBuffer* BufferAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -358,6 +398,86 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { }); } +TEST_F(HeapSimulatorTest, WholeModule) { + HeapSimulatorTracker tracker(TestName()); + + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + tracker.module()->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + tracker.module()->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + 
HloInstruction::CreateParameter(0, tuple_shape, "param")); + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, param)); + tracker.module()->AddEntryComputation(builder.Build()); + + tracker.RunWholeModule( + {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt}); + tracker.ExpectCallSequence({ + // The entry computation param and while_op are allocated first. + {kAlloc, tracker.BufferAt(param, {})}, + {kAlloc, tracker.BufferAt(param, {0})}, + {kAlloc, tracker.BufferAt(param, {1})}, + {kAlloc, tracker.BufferAt(while_op, {})}, + {kAlloc, tracker.BufferAt(while_op, {0})}, + {kAlloc, tracker.BufferAt(while_op, {1})}, + + // Now the while body param is allocated and freed. + {kAlloc, tracker.BufferAt(body_param, {})}, + {kAlloc, tracker.BufferAt(body_param, {0})}, + {kAlloc, tracker.BufferAt(body_param, {1})}, + {kFree, tracker.BufferAt(body_param, {})}, + {kFree, tracker.BufferAt(body_param, {0})}, + {kFree, tracker.BufferAt(body_param, {1})}, + + // Now the while cond param is allocated. The GTE instructions just alias + // the param elements, so the param tuple can immediately be freed. + {kAlloc, tracker.BufferAt(cond_param, {})}, + {kAlloc, tracker.BufferAt(cond_param, {0})}, + {kAlloc, tracker.BufferAt(cond_param, {1})}, + {kFree, tracker.BufferAt(cond_param, {})}, + + // Now the final cond less-than buffer is allocated. + {kAlloc, tracker.BufferAt(cond_lt, {})}, + + // The order of the remaining Free calls is based on the LogicalBuffer.id, + // which is deterministic, but not obvious. 
+ {kFree, tracker.BufferAt(param, {})}, + {kFree, tracker.BufferAt(param, {0})}, + {kFree, tracker.BufferAt(param, {1})}, + + {kFree, tracker.BufferAt(while_op, {})}, + {kFree, tracker.BufferAt(while_op, {0})}, + {kFree, tracker.BufferAt(while_op, {1})}, + + {kFree, tracker.BufferAt(cond_param, {0})}, + {kFree, tracker.BufferAt(cond_param, {1})}, + {kFree, tracker.BufferAt(cond_lt, {})}, + + {kFinish, nullptr}, + }); +} + // Base class for heap algorithm tests. class HeapAlgorithmTestBase : public ::testing::Test { protected: diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 89371e44973a55811af436b1f1d42f8f40b02159..a749814f0dfbfbacb7c09be815ef572bb00687c0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -35,10 +35,14 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +using ::tensorflow::strings::StrCat; + std::unique_ptr HloComputation::Builder::Build( HloInstruction* root_instruction) { int parameter_count = 0; @@ -91,12 +95,7 @@ HloInstruction* HloComputation::AddInstruction( HloInstruction* HloComputation::AddInstructionInternal( std::unique_ptr instruction) { // Generate a unique name for the instruction. 
- instruction->set_name( - instruction_name_uniquer_.GetUniqueName(instruction->name())); - if (instruction->opcode() == HloOpcode::kParameter) { - instruction->set_parameter_name( - instruction_name_uniquer_.GetUniqueName(instruction->parameter_name())); - } + instruction->UniquifyName(&instruction_name_uniquer_); Reparent(instruction.get()); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = @@ -131,9 +130,24 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - HloInstruction* new_instr = AddInstructionInternal( - HloInstruction::CreateParameter(param_no, param_instruction->shape(), - param_instruction->parameter_name())); + string param_name = param_instruction->name(); + // Fusion parameters are named foo.param_1, bar.param_2, etc. We are + // renumbering the parameters so replace the final number in the name with + // the updated value. + const string param_underscore = ".param_"; + size_t index = param_name.rfind(param_underscore); + if (index != string::npos) { + string after_param = param_name.substr(index + param_underscore.size()); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + param_name = + StrCat(param_name.substr(0, index), param_underscore, param_no); + } + } + + HloInstruction* new_instr = + AddInstructionInternal(HloInstruction::CreateParameter( + param_no, param_instruction->shape(), param_name)); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); new_instr->SetParentFusion(root_instruction_->fusion_instruction()); param_instructions_[param_no] = new_instr; @@ -672,4 +686,8 @@ std::unique_ptr HloComputation::Clone(const string& suffix) { return result; } +void HloComputation::UniquifyName(NameUniquer* name_uniquer) { + name_ = name_uniquer->GetUniqueName(name_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.h
b/tensorflow/compiler/xla/service/hlo_computation.h index fa274cfc6331343eeb22684c0d3f5c7f284dec76..62e00a24fbb523e1e30f08141f9e026407a2015d 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -133,7 +133,10 @@ class HloComputation { } const string& name() const { return name_; } - void set_name(const string& name) { name_ = name; } + + // Use the given NameUniquer to select a unique name for the computation based + // on the computation's existing name. + void UniquifyName(NameUniquer* name_uniquer); // Return a string representation of the computation. string ToString(int nested_level = 0) const; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 9a5345dc13d6db42553e9c343f7c81cd0e6c9d0e..cb0a99d773c57ba9a2fedc2842fe17cd5fe3571e 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -15,16 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" -#include -#include #include -#include #include #include #include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -34,52 +32,222 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" namespace xla { +namespace { + +template +static std::unique_ptr ConvertIfTypesMatch( + const Literal& src_literal) { + CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); + return LiteralUtil::Convert< + typename primitive_util::PrimitiveTypeToNative::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); +} + +template +static std::unique_ptr ConvertIfDestTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (primitive_dest_type) { +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal); + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) +#undef CONVERT_IF_TYPES_MATCH + // Other types are not yet supported. + default: + LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " + << PrimitiveType_Name(primitive_dest_type); + } +} + +static std::unique_ptr ConvertIfSrcTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (src_literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported.
+ default: + LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " + << PrimitiveType_Name(src_literal.shape().element_type()); + } +} + +} // namespace + +// ConstantFolderVisitor traverses the HLO computation and reduces certain +// constant graph sections, to literals. +class ConstantFolderVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) override; + + Status HandleConvert(HloInstruction* convert, + HloInstruction* operand) override; + + Status HandleReshape(HloInstruction* reshape) override; + + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + + Status HandleTranspose(HloInstruction* transpose) override; + + // Returns whether a constant folding operation has occurred. + const bool changed() const { return changed_; } + + // Runs the visitor on a computation and returns whether any changes were + // performed. + static StatusOr Run(HloComputation* computation); + + private: + ConstantFolderVisitor() = default; + + // Replaces the existing HLO instruction old_instruction, with a literal, + // and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceWithConstant(HloInstruction* old_instruction, + std::unique_ptr literal) { + TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( + old_instruction, HloInstruction::CreateConstant(std::move(literal)))); + changed_ = true; + return Status::OK(); + } + + // Whether any constant folding operations have occurred. 
+ bool changed_ = false; +}; + +StatusOr ConstantFolderVisitor::Run(HloComputation* computation) { + ConstantFolderVisitor visitor; + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.changed(); +} StatusOr HloConstantFolding::Run(HloModule* module) { + XLA_VLOG_LINES(2, + "HloConstantFolding::Run(), before:\n" + module->ToString()); bool changed = false; - for (auto& computation : module->computations()) { - for (auto instruction : computation->MakeInstructionPostOrder()) { - // Skip dead code. - if (instruction->user_count() == 0 && - computation->root_instruction() != instruction) { - continue; - } - // Depending on the opcode, choose how to handle constant operands. - // - // TODO(b/35975797): Fold constant computations for more than reshapes and - // transposes. - switch (instruction->opcode()) { - case HloOpcode::kReshape: { - if (instruction->operand(0)->opcode() == HloOpcode::kConstant) { - TF_ASSIGN_OR_RETURN( - auto reshaped_literal, - LiteralUtil::Reshape( - instruction->operand(0)->literal(), - AsInt64Slice(instruction->shape().dimensions()))); - TF_CHECK_OK(computation->ReplaceWithNewInstruction( - instruction, - HloInstruction::CreateConstant(std::move(reshaped_literal)))); - changed = true; - } - break; - } - case HloOpcode::kTranspose: { - if (instruction->operand(0)->opcode() == HloOpcode::kConstant) { - auto transposed_literal = LiteralUtil::Transpose( - instruction->operand(0)->literal(), instruction->dimensions()); - TF_CHECK_OK(computation->ReplaceWithNewInstruction( - instruction, - HloInstruction::CreateConstant(std::move(transposed_literal)))); - changed = true; - } - break; - } - default: - break; + for (auto& comp : module->computations()) { + TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get())); + changed = changed || result; + } + XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString()); + return changed; +} + +Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) 
{ + if (reshape->operand(0)->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN( + auto reshaped_literal, + LiteralUtil::Reshape(reshape->operand(0)->literal(), + AsInt64Slice(reshape->shape().dimensions()))); + return ReplaceWithConstant(reshape, std::move(reshaped_literal)); + } + return Status::OK(); +} + +Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) { + if (transpose->operand(0)->opcode() == HloOpcode::kConstant) { + auto transposed_literal = LiteralUtil::Transpose( + transpose->operand(0)->literal(), transpose->dimensions()); + return ReplaceWithConstant(transpose, std::move(transposed_literal)); + } + return Status::OK(); +} + +Status ConstantFolderVisitor::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + if (operands[0]->opcode() == HloOpcode::kConstant) { + // If all the operands of a concatenate are constant, fold them into a + // single constant tensor. + // The result concatenate dimension is going to be the sum of all the + // concatenate dimensions of the arrays taking part of the operation. + int64 concat_dim = concatenate->dimensions()[0]; + const Shape& reference_shape = operands[0]->shape(); + CHECK(!ShapeUtil::IsTuple(reference_shape)); + int64 rank = ShapeUtil::Rank(reference_shape); + std::vector concat_dimensions(reference_shape.dimensions().begin(), + reference_shape.dimensions().end()); + if (concat_dim < 0) { + concat_dim += rank; + } + for (int64 i = 1; i < operands.size(); ++i) { + const Shape& operand_shape = operands[i]->shape(); + CHECK(!ShapeUtil::IsTuple(operand_shape)); + if (operands[i]->opcode() != HloOpcode::kConstant) { + return Status::OK(); } + // Accumulate the concat dimension from all tensors taking part to the + // operation. 
+ concat_dimensions[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + auto literal = LiteralUtil::CreateFromDimensions( + reference_shape.element_type(), concat_dimensions); + std::vector source_indices(rank, 0); + std::vector dest_indices(concat_dimensions.size(), 0); + for (auto operand : operands) { + const Shape& operand_shape = operand->shape(); + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + operand->literal(), source_indices, literal.get(), dest_indices, + AsInt64Slice(operand_shape.dimensions()))); + dest_indices[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); } + return ReplaceWithConstant(concatenate, std::move(literal)); } - return changed; + return Status::OK(); +} + +Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice, + HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kConstant) { + const Shape& shape = slice->shape(); + auto literal = LiteralUtil::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + std::vector dest_indices(slice->slice_starts().size(), 0); + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + operand->literal(), slice->slice_starts(), literal.get(), dest_indices, + AsInt64Slice(shape.dimensions()))); + TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal))); + } + return Status::OK(); +} + +Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert, + HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kConstant) { + const Literal& src_literal = operand->literal(); + std::unique_ptr new_constant = + ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type()); + return ReplaceWithConstant(convert, std::move(new_constant)); + } + return Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 514bb8164c1e1fa10a36ceeeac63dc946de2ab5a..f45eccf825389609323eed9c5180dc385edc3092 100644 --- 
a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -25,12 +25,10 @@ namespace xla { // computation on constants. class HloConstantFolding : public HloPassInterface { public: - explicit HloConstantFolding() {} - ~HloConstantFolding() override {} tensorflow::StringPiece name() const override { return "constant_folding"; } - // Run ConstantFolding on the given module. Returns whether the module was - // changed (common subexpressions were found and eliminated). + // Run constant folding operations on the given module. Returns whether the + // module was changed (constant expressions folded). StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a56225da156dfc0a44b6a4b99191a3c7e706561f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" + +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using HloConstantFoldingTest = HloTestBase; + +TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ(LiteralUtil::GetFirstElement( + computation->root_instruction()->literal()), + 42); +} + +TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + builder.AddInstruction( + 
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ(LiteralUtil::GetFirstElement( + computation->root_instruction()->literal()), + 42.0f); +} + +TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0f, 19.0f}))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ( + LiteralUtil::Get(computation->root_instruction()->literal(), {0}), + 42); + EXPECT_EQ( + LiteralUtil::Get(computation->root_instruction()->literal(), {1}), + 19); +} + +TEST_F(HloConstantFoldingTest, Concatenate) { + const struct TestConfig { + int concat_dimension; + tensorflow::gtl::ArraySlice dimensions; + tensorflow::gtl::ArraySlice concat_sizes; + } test_configs[] = { + {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, + {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, + }; + + for (auto& test_config : test_configs) { + HloComputation::Builder builder(TestName()); + std::vector dimensions(test_config.dimensions.begin(), + test_config.dimensions.end()); + int64 concat_size = 0; + std::vector operands; + for (auto csize : 
test_config.concat_sizes) { + dimensions[test_config.concat_dimension] = csize; + concat_size += csize; + auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); + HloInstruction* insn = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + operands.push_back(insn); + } + dimensions[test_config.concat_dimension] = concat_size; + Shape shape = ShapeUtil::MakeShape(F32, dimensions); + builder.AddInstruction(HloInstruction::CreateConcatenate( + shape, operands, test_config.concat_dimension)); + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); + } +} + +TEST_F(HloConstantFoldingTest, Slice) { + HloComputation::Builder builder(TestName()); + const int64 dimensions[] = {11, 8, 7, 5, 9}; + const int64 slice_start[] = {4, 2, 3, 1, 5}; + const int64 slice_limits[] = {10, 8, 6, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); + builder.AddInstruction(HloInstruction::CreateSlice( + shape, literal_instruction, slice_start, slice_limits)); + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); +} + 
+TEST_F(HloConstantFoldingTest, TransposeConstantFold) { + HloComputation::Builder builder(TestName()); + const int64 dimensions[] = {11, 8, 7, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = LiteralUtil::CloneToUnique(*literal); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); + const int64 permutation[] = {1, 2, 0, 4, 3}; + builder.AddInstruction( + HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape)); + + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + bool matched = true; + LiteralUtil::EachCell( + root->literal(), + [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + matched = matched && (value == LiteralUtil::Get(*literal_clone, + rindexes)); + }); + EXPECT_TRUE(matched); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index ec8161f55fd56c95bb088a0c539255aed2fe6993..9444382b5270b0f76fa33b598297d24572e5b2c9 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -36,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -88,13 +91,15 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_NE(add->operand(0), add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); - EXPECT_EQ(add->operand(0), add->operand(1)); + auto first_operand = add->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); + EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); @@ -118,15 +123,13 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(constant1, add->operand(0)); - EXPECT_EQ(constant2, add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(constant1, add->operand(0)); - EXPECT_EQ(constant2, add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); auto 
result = ExecuteAndTransfer(std::move(module), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); @@ -185,16 +188,18 @@ TEST_F(HloCseTest, NonscalarConstants) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); + EXPECT_THAT(tuple, + op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(uncommon_constant, tuple->operand(2)); - EXPECT_TRUE(tuple->operand(0) == common_constant1 || - tuple->operand(0) == common_constant2); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, + ::testing::AnyOf(common_constant1, common_constant2)); + EXPECT_THAT(tuple, + op::Tuple(first_operand, first_operand, uncommon_constant)); } TEST_F(HloCseTest, IdenticalInstructions) { @@ -215,16 +220,15 @@ TEST_F(HloCseTest, IdenticalInstructions) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); - EXPECT_NE(tuple->operand(1), tuple->operand(2)); - EXPECT_NE(tuple->operand(0), tuple->operand(2)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(tuple->operand(1), tuple->operand(2)); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2, exp3)); + EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand)); } TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { @@ -249,13 +253,13 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { auto computation = 
module.AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); EXPECT_FALSE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); } TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { @@ -280,13 +284,15 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2)); + EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand)); } TEST_F(HloCseTest, IdenticalExpressions) { @@ -328,14 +334,15 @@ TEST_F(HloCseTest, IdenticalExpressions) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(8, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(HloOpcode::kAdd, tuple->operand(0)->opcode()); + auto operand = tuple->operand(0); + EXPECT_THAT(tuple, op::Tuple(operand, operand)); + EXPECT_THAT(operand, op::Add(op::Negate(), op::Exp())); } TEST_F(HloCseTest, DoNotCombineRng) { @@ -351,12 +358,16 @@ TEST_F(HloCseTest, DoNotCombineRng) { auto rng2 
= builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); + builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, rng1, rng2)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(rng1, rng2)); + uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); @@ -364,11 +375,8 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kRng); - EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kRng); - EXPECT_NE(root->operand(0), root->operand(1)); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(rng1, rng2)); } // TODO(b/28245743): Handle impure functions correctly in CSE. 
@@ -412,16 +420,17 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { } EXPECT_EQ(4, computation->instruction_count()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(op::Map(), op::Map())); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMap); - EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kMap); - EXPECT_NE(root->operand(0), root->operand(1)); + root = computation->root_instruction(); + auto operand = root->operand(0)->operand(0); + EXPECT_THAT(operand, op::Map()); + EXPECT_THAT(root, op::Add(operand, operand)); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index ebe74280525010436423163f746ddee6a23dc7e1..e0447d69aa2229e2cb391aac8b2afa8fde6145c1 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -26,342 +26,532 @@ limitations under the License. 
#include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { -namespace { +template +class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { + public: + explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} -template -std::unique_ptr ElementWiseUnaryOp( - const Shape& shape, std::function&& unary_op, - const Literal& operand) { - DCHECK(ShapeUtil::SameDimensions(shape, operand.shape())); + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", + HloOpcodeString(hlo_instruction->opcode()).c_str()); + }; - auto result = MakeUnique(); - *result->mutable_shape() = shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get()); + // TODO(b/35950897): many of the stl functions used in the handlers are not + // overloaded for every XLA primitive types. 
- std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - unary_op(LiteralUtil::Get(operand, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + template ::value>::type* = + nullptr> + Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return elem_operand; + })); + return Status::OK(); + }; - return result; -} + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return std::abs(elem_operand); + })); + return Status::OK(); + }; -template -std::unique_ptr ElementWiseBinaryOp( - const Shape& shape, std::function&& binary_op, - const Literal& lhs, const Literal& rhs) { - DCHECK(ShapeUtil::SameDimensions(shape, rhs.shape())); - DCHECK(ShapeUtil::SameDimensions(lhs.shape(), rhs.shape())); - - auto result = MakeUnique(); - *result->mutable_shape() = shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get()); - - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - binary_op(LiteralUtil::Get(lhs, multi_index), - LiteralUtil::Get(rhs, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); - - return result; -} + Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override { + return HandleAbs(abs, operand); + }; -template -std::unique_ptr ElementWiseTernaryOp( - const Shape& shape, - std::function&& ternary_op, - const Literal& lhs, const Literal& rhs, const Literal& ehs) { - DCHECK(ShapeUtil::SameDimensions(shape, lhs.shape())); - DCHECK(ShapeUtil::SameDimensions(lhs.shape(), rhs.shape())); - DCHECK(ShapeUtil::SameDimensions(rhs.shape(), 
ehs.shape())); - - auto result = MakeUnique(); - *result->mutable_shape() = shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(shape), result.get()); - - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - ternary_op(LiteralUtil::Get(lhs, multi_index), - LiteralUtil::Get(rhs, multi_index), - LiteralUtil::Get(ehs, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); - - return result; -} + Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], + ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { + return std::ceil(elem_operand); + })); + return Status::OK(); + }; -// Templated abs so that unsigned types can be passed in without warning. -template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> -NativeT AbsoluteVal(NativeT value) { - return value; -} + Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy], + ElementWiseUnaryOp(copy, [](ReturnT elem_operand) { + return elem_operand; + })); + return Status::OK(); + }; -template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> -NativeT AbsoluteVal(NativeT value) { - return std::abs(value); -} + Status HandleExp(HloInstruction* exp, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], + ElementWiseUnaryOp(exp, [](ReturnT elem_operand) { + return std::exp(elem_operand); + })); + return Status::OK(); + }; -template -StatusOr> EvaluateOpForLiteralInternal( - HloInstruction* instruction) { - DCHECK(hlo_query::AllOperandsAreConstants(*instruction)); - - const std::vector& operands = instruction->operands(); - HloOpcode opcode = instruction->opcode(); - const Shape& shape = instruction->shape(); - - switch (opcode) { - // TODO(b/35950897): many of the stl function used here are not overloaded - // for all XLA 
primitive types. - // Unary element-wise ops. - case HloOpcode::kAbs: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return AbsoluteVal(operand); }, - operands[0]->literal()); - case HloOpcode::kCeil: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return std::ceil(operand); }, - operands[0]->literal()); - case HloOpcode::kConvert: - CHECK_EQ(operands.size(), 1); - // TODO(b/35950897): implement Convert. - return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(opcode).c_str()); - case HloOpcode::kCopy: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return operand; }, - operands[0]->literal()); - case HloOpcode::kExp: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return std::exp(operand); }, - operands[0]->literal()); - case HloOpcode::kFloor: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return std::floor(operand); }, - operands[0]->literal()); - case HloOpcode::kIsFinite: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return std::isfinite(operand); }, - operands[0]->literal()); - case HloOpcode::kLog: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return std::log(operand); }, - operands[0]->literal()); - case HloOpcode::kLogicalNot: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return !operand; }, - operands[0]->literal()); - case HloOpcode::kNegate: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return -operand; }, - operands[0]->literal()); - case HloOpcode::kSign: - CHECK_EQ(operands.size(), 1); - CHECK(primitive_util::IsIntegralType(shape.element_type())); - return ElementWiseUnaryOp(shape, - [](NativeT operand) { - return (NativeT(0) 
< operand) - - (operand < NativeT(0)); - }, - operands[0]->literal()); - case HloOpcode::kTanh: - CHECK_EQ(operands.size(), 1); - return ElementWiseUnaryOp( - shape, [](NativeT operand) { return std::tanh(operand); }, - operands[0]->literal()); - // Binary element-wise ops. - case HloOpcode::kAdd: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs + rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kDivide: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs / rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kMultiply: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs * rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kSubtract: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs - rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kEq: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs == rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kGe: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs >= rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kGt: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs > rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kLe: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs <= rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kLt: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { 
return lhs < rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kNe: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs != rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kMaximum: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return std::max(lhs, rhs); }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kMinimum: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return std::min(lhs, rhs); }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kPower: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return std::pow(lhs, rhs); }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kRemainder: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, - [](NativeT lhs, NativeT rhs) { return std::remainder(lhs, rhs); }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kLogicalAnd: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs && rhs; }, - operands[0]->literal(), operands[1]->literal()); - case HloOpcode::kLogicalOr: - CHECK_EQ(operands.size(), 2); - return ElementWiseBinaryOp( - shape, [](NativeT lhs, NativeT rhs) { return lhs || rhs; }, - operands[0]->literal(), operands[1]->literal()); - // Ternary element-wise ops. 
- case HloOpcode::kClamp: { - CHECK_EQ(operands.size(), 3); - std::function clamp_op = - [](NativeT low, NativeT high, NativeT value) { - return std::max(low, std::min(value, high)); - }; - return ElementWiseTernaryOp( - shape, std::move(clamp_op), operands[0]->literal(), - operands[1]->literal(), operands[2]->literal()); - } break; - case HloOpcode::kSelect: { - CHECK_EQ(operands.size(), 3); - CHECK(!ShapeUtil::IsTuple(shape)); - std::function select_op = - [](bool pred, NativeT on_true, NativeT on_false) { - if (pred) { - return on_true; - } - return on_false; - }; - return ElementWiseTernaryOp( - shape, std::move(select_op), operands[0]->literal(), - operands[1]->literal(), operands[2]->literal()); - } break; - default: - return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(opcode).c_str()); - } + Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ReturnT elem_operand) { + return std::floor(elem_operand); + })); + return Status::OK(); + }; + + Status HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite], + ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) { + return std::isfinite(elem_operand); + })); + return Status::OK(); + }; + + Status HandleLog(HloInstruction* log, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], + ElementWiseUnaryOp(log, [](ReturnT elem_operand) { + return std::log(elem_operand); + })); + return Status::OK(); + }; + + Status HandleLogicalNot(HloInstruction* logical_not, + HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[logical_not], + ElementWiseUnaryOp(logical_not, + [](ReturnT elem_operand) { return !elem_operand; })); + return Status::OK(); + }; + + Status HandleNegate(HloInstruction* negate, + HloInstruction* operand) override { + 
TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { + return -elem_operand; + })); + return Status::OK(); + }; + + Status HandleSign(HloInstruction* sign, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { + return (ReturnT(0) < elem_operand) - + (elem_operand < ReturnT(0)); + })); + return Status::OK(); + }; + + Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], + ElementWiseUnaryOp(tanh, [](ReturnT elem_operand) { + return std::tanh(elem_operand); + })); + return Status::OK(); + }; + + Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return lhs_elem * rhs_elem; + })); + return Status::OK(); + }; + + Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[subtract], + ElementWiseBinaryOp(subtract, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return lhs_elem - rhs_elem; + })); + return Status::OK(); + }; + + Status HandleAdd(HloInstruction* add, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return lhs_elem + rhs_elem; + })); + return Status::OK(); + }; + + Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return lhs_elem / rhs_elem; + })); + return Status::OK(); + }; + + Status HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) 
override { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el != rhs_el; + }; + break; + case HloOpcode::kGe: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el >= rhs_el; + }; + break; + case HloOpcode::kGt: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el > rhs_el; + }; + break; + case HloOpcode::kLe: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el <= rhs_el; + }; + break; + case HloOpcode::kLt: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el < rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Compare operation with mismatched dimensions, likely due to " + "broadcasting is unsupported."); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = LiteralUtil::CreateFromShape(compare->shape()); + std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); + do { + LiteralUtil::Set( + result.get(), multi_index, + compare_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index))); + } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + parent_->evaluated_[compare] = std::move(result); + + return Status::OK(); + }; + + Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, 
[](ReturnT lhs, ReturnT rhs) { + return std::max(lhs, rhs); + })); + return Status::OK(); + }; + + Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { + return std::min(lhs_el, rhs_el); + })); + return Status::OK(); + }; + + Status HandlePower(HloInstruction* power, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ReturnT lhs_el, ReturnT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); + return Status::OK(); + }; + + Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { + return std::remainder(lhs_el, rhs_el); + })); + return Status::OK(); + }; + + Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[logical_and], + ElementWiseBinaryOp(logical_and, [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el && rhs_el; + })); + return Status::OK(); + }; + + Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[logical_or], + ElementWiseBinaryOp(logical_or, [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el || rhs_el; + })); + return Status::OK(); + }; + + Status HandleClamp(HloInstruction* clamp, HloInstruction* min, + HloInstruction* arg, HloInstruction* max) override { + std::function clamp_op = + [](ReturnT low, ReturnT high, ReturnT value) { + return std::max(low, std::min(value, high)); + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], + ElementWiseTernaryOp(clamp, std::move(clamp_op))); + return Status::OK(); + }; + + Status 
HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) override { + CHECK(!ShapeUtil::IsTuple(select->shape())); + std::function select_op = + [](bool pred, ReturnT on_true, ReturnT on_false) { + if (pred) { + return on_true; + } + return on_false; + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], + ElementWiseTernaryOp(select, std::move(select_op))); + return Status::OK(); + }; + + Status Preprocess(HloInstruction* hlo) override { + VLOG(2) << hlo->ToString(); + return Status::OK(); + }; + + private: + StatusOr> ElementWiseUnaryOp( + HloInstruction* instruction, + const std::function& unary_op) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + + auto result = LiteralUtil::CreateFromShape(shape); + + std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); + do { + LiteralUtil::Set( + result.get(), multi_index, + unary_op(LiteralUtil::Get(operand_literal, multi_index))); + } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + return std::move(result); + }; + + StatusOr> ElementWiseBinaryOp( + HloInstruction* instruction, + const std::function& binary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. 
+ if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = LiteralUtil::CreateFromShape(shape); + std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); + do { + LiteralUtil::Set( + result.get(), multi_index, + binary_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index))); + } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + return std::move(result); + }; + + template + StatusOr> ElementWiseTernaryOp( + HloInstruction* instruction, + const std::function& ternary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + const auto* ehs = instruction->operand(2); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. 
+ if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str(), + ShapeUtil::HumanString(ehs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); + + auto result = LiteralUtil::CreateFromShape(shape); + std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); + do { + LiteralUtil::Set( + result.get(), multi_index, + ternary_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index), + LiteralUtil::Get(ehs_literal, multi_index))); + } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + return std::move(result); + }; + + HloEvaluator* parent_; +}; + +HloEvaluator::HloEvaluator() { + typed_visitors_[PRED] = MakeUnique>(this); + typed_visitors_[U8] = MakeUnique>(this); + typed_visitors_[U16] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: U16."); + }); + typed_visitors_[U32] = MakeUnique>(this); + typed_visitors_[U64] = MakeUnique>(this); + typed_visitors_[S8] = MakeUnique>(this); + typed_visitors_[S16] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: S16."); + }); + typed_visitors_[S32] = MakeUnique>(this); + typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[F16] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: F16."); + }); + typed_visitors_[F32] = MakeUnique>(this); + typed_visitors_[F64] = MakeUnique>(this); } -} // namespace +StatusOr> 
HloEvaluator::Evaluate( + HloComputation* computation, + tensorflow::gtl::ArraySlice args) { + arg_literals_ = args; + evaluated_.clear(); -/* static */ StatusOr> -HloEvaluator::EvaluateOpForLiteral(HloInstruction* instruction) { - DCHECK(hlo_query::AllOperandsAreConstants(*instruction)); + TF_RETURN_IF_ERROR(computation->Accept(this)); + return std::move(FindOrDie(evaluated_, computation->root_instruction())); +} +StatusOr> HloEvaluator::Evaluate( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice operands) { + DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); Shape shape = instruction->shape(); TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); - // REVIEW QUESTION: other than a few operations, do we need to handle the - // general case of operands being of different types in the context of the - // evaluator? - - switch (shape.element_type()) { - case PRED: - return EvaluateOpForLiteralInternal(instruction); - case U8: - return EvaluateOpForLiteralInternal(instruction); - case U16: - LOG(FATAL) << "U16/uint16 is unimplemented."; - case U32: - return EvaluateOpForLiteralInternal(instruction); - case U64: - return EvaluateOpForLiteralInternal(instruction); - case S8: - return EvaluateOpForLiteralInternal(instruction); - case S16: - LOG(FATAL) << "S16/int16 is unimplemented."; - case S32: - return EvaluateOpForLiteralInternal(instruction); - case S64: - return EvaluateOpForLiteralInternal(instruction); - case F16: - LOG(FATAL) << "F16 is unimplemented."; - case F32: - return EvaluateOpForLiteralInternal(instruction); - case F64: - return EvaluateOpForLiteralInternal(instruction); - default: - return Unimplemented("unhandled primitive type: %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + arg_literals_ = operands; + evaluated_.clear(); + + // Evaluate operands of Parameter type against the input literals which + // caches the evaluated literal results. 
+ for (const auto operand : instruction->operands()) { + if (operand->opcode() == HloOpcode::kParameter) { + const Literal* input_literal = arg_literals_[operand->parameter_number()]; + VLOG(2) << "Parameter operand evaluated to: " + << LiteralUtil::ToString(*input_literal); + TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); + + evaluated_[operand] = MakeUnique(*input_literal); + } else if (operand->opcode() == HloOpcode::kConstant) { + evaluated_[operand] = MakeUnique(operand->literal()); + } } + + TF_RETURN_IF_ERROR(instruction->Visit(this)); + return std::move(FindOrDie(evaluated_, instruction)); +} + +Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + VLOG(2) << "HandleParameter: " << parameter->ToString(); + const Literal* input_literal = arg_literals_[parameter->parameter_number()]; + VLOG(2) << "Parameter evaluated to: " + << LiteralUtil::ToString(*input_literal); + DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); + + evaluated_[parameter] = MakeUnique(*input_literal); + return Status::OK(); +} + +Status HloEvaluator::HandleConstant(HloInstruction* constant, + const Literal& literal) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape())); + + evaluated_[constant] = MakeUnique(literal); + return Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c6ec650d674117f8bcbc9517a76b16c5940981d2..50cb32eb85c04d8b3abe4cd0b46a4f8c10e9c568 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,22 +18,105 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/macros.h" namespace xla { -// Responsible for evaluating a HLO instruction with constant operands. -class HloEvaluator { +// Responsible for evaluating HLO and obtain literal as the evaluation results. +// +// This class is not thread-safe. +class HloEvaluator : public DfsHloVisitorWithDefault { public: - // Evaluates a single HLO instruction for constants and return the result as a - // Literal. - // Precondition: all operands of the instruction are constants, instruction is - // valid with corresponding number of operands for the given operator. + HloEvaluator(); + // Evaluates a HLO computation and an array of pointers to literals. + // Return the evaluated result as literal if successful. + // Precondition: argument literals are corresponds to the input computation's + // parameters in their post-ordering. For e.g., consider the following graph: + // + // * + // / \ + // + Parameter1 + // / \ + // / \ + // Parameter0 Constant + // + // The input literals array will have its first literal map to Parameter0 and + // the second map to Parameter1. + StatusOr> Evaluate( + HloComputation* computation, + tensorflow::gtl::ArraySlice arg_literals); + + // Evaluates a single HLO instruction and an array of pointers to literals. + // Return the evaluated result as literal if successful. + // Precondition: + // 1. argument literals are corresponds to the input instruction's + // parameters in their post-orderring. + // 2. the instruction's operands must be of either Parameter or Constant type. 
// TODO(b/35950897): implement more ops other than element-wise ops. - static StatusOr> EvaluateOpForLiteral( - HloInstruction* instruction); + StatusOr> Evaluate( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice arg_literals); + + protected: + // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting + // literal type of each evaluated Handle* method of a TypedVisitor. One + // exception to this is HandleCompare, where the resulting literal type is + // always boolean. + // Note the forward declaration here is necessary to enable TypedVisitor to + // access parent members. + template + class TypedVisitor; + + // Wraps around instruction handling to infer types before dispatching to + // the corresponding typed Visitor. + Status DefaultAction(HloInstruction* hlo) override { + return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); + } + + Status HandleParameter(HloInstruction* parameter) override; + + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override; + + private: + // Returns the already-evaluated literal result for the instruction. + // Crash with log if the given instruction has not been evaluated previously. + const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { + auto it = evaluated_.find(hlo); + CHECK(it != evaluated_.end()) + << "could not find evaluated value for: " << hlo->ToString(); + return *(it->second); + } + + // Map from a primitive type to its associated (templated) DfsHloVisitor. + // Note: the hash function here is only needed because current gcc std::hash + // does not specialize for enum types. This should however be fixed in the + // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 + tensorflow::gtl::FlatMap, + std::hash> + typed_visitors_; + + // Tracks the HLO instruciton and its evaluated literal result. 
+ // TODO(b/35950897): have better memory management here to free instructions + // that are no longer a parent for any other subsequent instruction in + // post-orderring. + tensorflow::gtl::FlatMap> + evaluated_; + + // Stores input literals, assuming they are in post-order. Literals are not + // owned by this class, and they must outlive the lifetime of the instance of + // this class. + tensorflow::gtl::ArraySlice arg_literals_; + + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 585fe65def3fef4f271b5cfbbb500d3f7a0eba59..443e5ad4f4290ff10b867887ac5ed359a0c8f73a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -14,10 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include #include #include +#include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -29,9 +32,16 @@ limitations under the License. namespace xla { namespace { +class HloEvaluatorTest : public ::testing::Test { + protected: + HloEvaluatorTest() { evaluator_ = MakeUnique(); } + + std::unique_ptr evaluator_; +}; + // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. 
-TEST(HloEvaluatorTest, DoesClamp) { +TEST_F(HloEvaluatorTest, DoesClamp) { auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -44,7 +54,7 @@ TEST(HloEvaluatorTest, DoesClamp) { shape, HloOpcode::kClamp, c1.get(), c2.get(), c3.get()); std::unique_ptr result = - HloEvaluator::EvaluateOpForLiteral(instruction.get()).ConsumeValueOrDie(); + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); @@ -53,7 +63,7 @@ TEST(HloEvaluatorTest, DoesClamp) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. -TEST(HloEvaluatorTest, DoesSelect) { +TEST_F(HloEvaluatorTest, DoesSelect) { auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -66,7 +76,7 @@ TEST(HloEvaluatorTest, DoesSelect) { shape, HloOpcode::kSelect, c1.get(), c2.get(), c3.get()); std::unique_ptr result = - HloEvaluator::EvaluateOpForLiteral(instruction.get()).ConsumeValueOrDie(); + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); @@ -75,7 +85,7 @@ TEST(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. 
-TEST(HloEvaluatorTest, DoesAdd) { +TEST_F(HloEvaluatorTest, DoesAdd) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); @@ -86,7 +96,7 @@ TEST(HloEvaluatorTest, DoesAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1.get(), c2.get()); std::unique_ptr result = - HloEvaluator::EvaluateOpForLiteral(instruction.get()).ConsumeValueOrDie(); + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); @@ -95,7 +105,7 @@ TEST(HloEvaluatorTest, DoesAdd) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST(HloEvaluatorTest, DoesDivide) { +TEST_F(HloEvaluatorTest, DoesDivide) { auto lhs_s64 = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs_s64 = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); @@ -106,7 +116,7 @@ TEST(HloEvaluatorTest, DoesDivide) { c1_s64.get(), c2_s64.get()); std::unique_ptr result = - HloEvaluator::EvaluateOpForLiteral(instruction.get()).ConsumeValueOrDie(); + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); @@ -121,8 +131,7 @@ TEST(HloEvaluatorTest, DoesDivide) { instruction = HloInstruction::CreateBinary(shape_f64, HloOpcode::kDivide, c1_f64.get(), c2_f64.get()); - result = - HloEvaluator::EvaluateOpForLiteral(instruction.get()).ConsumeValueOrDie(); + result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); expected = LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); @@ -132,21 +141,51 @@ TEST(HloEvaluatorTest, DoesDivide) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. 
-TEST(HloEvaluatorTest, DoesAbs) { +TEST_F(HloEvaluatorTest, DoesAbs) { auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); - Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(operand)); auto instruction = HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get()); std::unique_ptr result = - HloEvaluator::EvaluateOpForLiteral(instruction.get()).ConsumeValueOrDie(); + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); } +// Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor +// constant operands. +TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { + HloComputation::Builder builder( + ::testing::UnitTest::GetInstance()->current_test_info()->name()); + + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; + + Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + + auto param_lhs = HloInstruction::CreateParameter(0, shape, "lhs"); + auto param_rhs = HloInstruction::CreateParameter(1, shape, "rhs"); + auto lhs_instruction = HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, param_lhs.get(), param_rhs.get()); + + auto param_rhs2 = HloInstruction::CreateParameter(2, shape, "rhs2"); + auto root_instruction = HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, lhs_instruction.get(), param_rhs2.get()); + + builder.AddInstruction(std::move(root_instruction)); + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc 
b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 5ce52af5b4616050988d2dba653c23d8acedf0d8..eb2e5dfb37f33fd138e20ee930a2242cb1db89ea 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -48,6 +48,73 @@ namespace xla { namespace hlo_graph_dumper { namespace { +// Node color schemes, used by NodeColorAttributes. +enum ColorScheme { + kBlue, + kBrown, + kDarkBlue, + kDarkGreen, + kDarkRed, + kGray, + kGreen, + kOrange, + kPurple, + kRed, + kWhite, + kYellow, +}; + +// Given a ColorScheme, returns an attribute string for a node of that color. +// Sets the node's fill, stroke, and text colors. +// +// Colors are from https://material.io/color. +string NodeColorAttributes(ColorScheme color) { + using std::make_tuple; + + const char *fill_color, *stroke_color, *font_color; + std::tie(fill_color, stroke_color, font_color) = + [color]() -> std::tuple { + switch (color) { + case kBlue: + return make_tuple("#bbdefb", "#8aacc8", "black"); + case kBrown: + return make_tuple("#bcaaa4", "#8c7b75", "black"); + case kDarkBlue: + return make_tuple("#1565c0", "#003c8f", "white"); + case kDarkGreen: + return make_tuple("#2e7d32", "#005005", "white"); + case kDarkRed: + return make_tuple("#b71c1c", "#7f0000", "white"); + case kGray: + return make_tuple("#cfd8dc", "#9ea7aa", "black"); + case kGreen: + return make_tuple("#c8e6c9", "#97b498", "black"); + case kOrange: + return make_tuple("#ffe0b2", "#cbae82", "black"); + case kPurple: + return make_tuple("#e1bee7", "#af8eb5", "black"); + case kRed: + return make_tuple("#ffcdd2", "#cb9ca1", "black"); + case kWhite: + return make_tuple("white", "black", "black"); + case kYellow: + return make_tuple("#fff9c4", "#cbc693", "black"); + } + }(); + + return Printf( + "style=filled, fontcolor=\"%s\", color=\"%s\", fillcolor=\"%s\"", + font_color, stroke_color, fill_color); +} + +// Replaces <> with <>, so that this string is safe(er) for use in a +// graphviz 
HTML-like string. +string HtmlLikeStringSanitize(tensorflow::StringPiece s) { + return tensorflow::str_util::StringReplace( + tensorflow::str_util::StringReplace(s, "<", "<", /*replace_all=*/true), + ">", ">", /*replace_all=*/true); +} + // Returns the dot graph identifier for the given instruction. string InstructionId(const HloInstruction* instruction) { return Printf("%lld", reinterpret_cast(instruction)); @@ -102,30 +169,36 @@ string InstructionSequenceGraph( param_ports.push_back( Printf("<%s> %s", InstructionId(param).c_str(), label.c_str())); } - StrAppend(&graph_body, param_node_name, - " [shape=record,style=filled,fillcolor=\"lightblue1\",", - "label=\"{parameters | {", Join(param_ports, "|"), "}}\"];\n"); + // (If we wanted the word "parameters" to be bold like the other op names, + // we'd have to make this into an HTML-like table. It is possible but + // complicated; see http://www.graphviz.org/doc/info/shapes.html#html.) + StrAppend(&graph_body, param_node_name, " [shape=record ", + NodeColorAttributes(kOrange), "label=\"{parameters | {", + Join(param_ports, "|"), "}}\"];\n"); } for (auto& instruction : instructions) { - string color = "peachpuff"; - string shape = "ellipse"; - string name = instruction->ExtendedOpcodeStr(); + ColorScheme color = kYellow; + string shape = "box"; + string name = + StrCat("", HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()), + " ", HtmlLikeStringSanitize(instruction->name())); if (HloOpcode::kConvolution == instruction->opcode()) { - name += ":\\n" + instruction->ConvolutionDimensionNumbersToString() + - "\\n" + window_util::ToString(instruction->window()); + StrAppend( + &name, "
", + HtmlLikeStringSanitize( + instruction->ConvolutionDimensionNumbersToString()), + "
", + HtmlLikeStringSanitize(window_util::ToString(instruction->window()))); } - name += "\\n" + instruction->name(); - if (!instruction->metadata().op_type().empty()) { - StrAppend(&name, "\\n", instruction->metadata().op_type()); - } if (!instruction->metadata().op_name().empty()) { - StrAppend(&name, "\\n", instruction->metadata().op_name()); + StrAppend(&name, "
", + HtmlLikeStringSanitize(instruction->metadata().op_name())); } if (!instruction->metadata().source_file().empty() && instruction->metadata().source_line() != 0) { - StrAppend(&name, "\\n", instruction->metadata().source_file(), ":", + StrAppend(&name, "
", instruction->metadata().source_file(), ":", instruction->metadata().source_line()); } @@ -140,11 +213,8 @@ string InstructionSequenceGraph( case HloOpcode::kAdd: case HloOpcode::kCeil: case HloOpcode::kClamp: - case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kDivide: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: @@ -163,64 +233,49 @@ string InstructionSequenceGraph( case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: - case HloOpcode::kPad: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kReshape: - case HloOpcode::kReverse: case HloOpcode::kSelect: case HloOpcode::kSign: case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSubtract: case HloOpcode::kTanh: - case HloOpcode::kTuple: - case HloOpcode::kUpdate: - break; - - case HloOpcode::kBroadcast: - case HloOpcode::kTranspose: - StrAppend(&name, "\\n", "dims={", Join(instruction->dimensions(), ","), - "}"); - break; - case HloOpcode::kGetTupleElement: - StrAppend(&name, "\\nindex=", instruction->tuple_index()); break; case HloOpcode::kRng: - StrAppend(&name, "\\n", + StrAppend(&name, "
", RandomDistribution_Name(instruction->random_distribution())); break; - case HloOpcode::kConstant: - shape = "box"; - color = "palegreen"; - if (ShapeUtil::IsScalar(instruction->shape())) { - StrAppend(&name, "\\n", "value=", LiteralUtil::GetAsString( - instruction->literal(), {})); - } + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + StrAppend(&name, "
", "dims={", + Join(instruction->dimensions(), ","), "}"); break; case HloOpcode::kBitcast: - case HloOpcode::kCopy: - color = "white"; - break; - case HloOpcode::kCall: - color = "tomato"; - break; - case HloOpcode::kCustomCall: - color = "tomato4"; - StrAppend(&name, "\\n", - "custom_call_target=", instruction->custom_call_target()); + case HloOpcode::kTuple: + case HloOpcode::kTrace: + color = kWhite; break; - case HloOpcode::kDot: - color = "slateblue"; + case HloOpcode::kGetTupleElement: + color = kWhite; + StrAppend(&name, "
index=", instruction->tuple_index()); break; - case HloOpcode::kSend: - color = "purple"; + case HloOpcode::kConcatenate: + case HloOpcode::kCopy: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kUpdate: + color = kGreen; break; - case HloOpcode::kRecv: - color = "orange"; + case HloOpcode::kConstant: + color = kBlue; break; - case HloOpcode::kMap: - color = "palevioletred"; + case HloOpcode::kConvolution: + case HloOpcode::kDot: + color = kDarkBlue; break; case HloOpcode::kParameter: // A single record node is created for all the parameter nodes with a @@ -229,38 +284,54 @@ string InstructionSequenceGraph( continue; case HloOpcode::kReduce: StrAppend(&name, " dims=", Join(instruction->dimensions(), ",")); - color = "lightsalmon"; + color = kPurple; break; case HloOpcode::kSelectAndScatter: case HloOpcode::kReduceWindow: - color = "lightsalmon"; - break; - case HloOpcode::kTrace: - color = "white"; + color = kPurple; break; case HloOpcode::kWhile: - color = "forestgreen"; + shape = "ellipse"; + color = kDarkGreen; break; + case HloOpcode::kMap: case HloOpcode::kFusion: - color = "gray"; - break; - case HloOpcode::kConvolution: - color = "red"; - break; - case HloOpcode::kCrossReplicaSum: - color = "turquoise"; + color = kGray; break; + case HloOpcode::kSend: + case HloOpcode::kRecv: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - color = "blue"; + case HloOpcode::kCrossReplicaSum: + color = kBrown; + break; + case HloOpcode::kCall: + color = kDarkGreen; + break; + case HloOpcode::kCustomCall: + color = kDarkGreen; + StrAppend(&name, "
", + "custom_call_target=", instruction->custom_call_target()); break; } // Create instruction node with appropriate label, shape, and color. + // label is interpreted as an HTML-like string, so newlines must be + // delimited with
, rather than \n. string label = - StrCat(name, "\\n", ShapeUtil::HumanString(instruction->shape())); + StrCat(name, "
", ShapeUtil::HumanString(instruction->shape())); + + if (instruction->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instruction->shape())) { + auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( + instruction->shape(), /*linear_index=*/0); + StrAppend(&label, " = {", + LiteralUtil::GetAsString(instruction->literal(), elem_idx), + "}"); + } + if (show_addresses) { - Appendf(&label, "\\n[%p]", instruction.get()); + Appendf(&label, "
[%p]", instruction.get()); } if (show_layouts && LayoutUtil::HasLayout(instruction->shape())) { string layout_string; @@ -272,7 +343,7 @@ string InstructionSequenceGraph( layout_string = Join(instruction->shape().layout().minor_to_major(), ","); } - StrAppend(&label, "\\nlayout={", layout_string, "}"); + StrAppend(&label, "
layout={", layout_string, "}"); } if (hlo_execution_profile != nullptr) { auto hlo_cycles_executed = @@ -280,16 +351,16 @@ string InstructionSequenceGraph( auto total_cycles_executed = hlo_execution_profile->total_cycles_executed(*instruction->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { - Appendf(&label, "\\n%% of cycles executed=%.2f", + Appendf(&label, "
%% of cycles executed=%.2f", (static_cast(hlo_cycles_executed) / static_cast(total_cycles_executed)) * 100); } } - Appendf(&graph_body, - "%s [label=\"%s\", shape=%s, style=filled, fillcolor=%s];\n", + + Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", InstructionId(instruction.get()).c_str(), label.c_str(), - shape.c_str(), color.c_str()); + shape.c_str(), NodeColorAttributes(color).c_str()); // Create edges from the instruction's operands to the instruction. int64 operand_number = 0; @@ -319,7 +390,7 @@ string InstructionSequenceGraph( StrCat("cluster_", InstructionId(instruction.get())); StrAppend(&graph_body, "subgraph ", cluster_name, " {\n"); StrAppend(&graph_body, - "label=\"fused expression\";\nstyle=filled;\n" + "label=<fused expression>;\nstyle=\"rounded,filled\";\n" "color=lightgrey;\n"); StrAppend(&graph_body, InstructionSequenceGraph( instruction->fused_instructions(), @@ -349,19 +420,39 @@ string InstructionSequenceGraph( return graph_body; } +// DOT graphs accept a stylesheet as a URL. So naturally, an inline stylesheet +// is a data URI! +// +// We don't perform any escaping on this string, so be careful not to use double +// quotes inside. +static const char* dot_stylesheet = R"( +data:text/css, +@import url(https://fonts.googleapis.com/css?family=Roboto:400,700); +svg text { + font-family: 'Roboto'; + font-size: 12px; +} +)"; + string ComputationToDotGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile) { - string graph_label = StrCat(label, "\\n", computation.name()); + string graph_label = StrCat(label, "
", computation.name()); if (hlo_execution_profile != nullptr) { auto cycles = hlo_execution_profile->total_cycles_executed(computation); - Appendf(&graph_label, "\\ntotal cycles = %lld (%s)", cycles, + Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, tensorflow::strings::HumanReadableNum(cycles).c_str()); } - string graph = - Printf("digraph G {\nrankdir=TB;\ncompound=true;\nlabel=\"%s\"\n", - graph_label.c_str()); + string graph = Printf( + R"(digraph G { +rankdir=TB; +compound=true; +label=<%s>; +labelloc=t; +stylesheet="%s" +)", + graph_label.c_str(), dot_stylesheet); // Emit embedded computations as subgraph clusters. std::vector intercomputation_edges; @@ -369,7 +460,9 @@ string ComputationToDotGraph(const HloComputation& computation, string graph_body = InstructionSequenceGraph( embedded->instructions(), show_addresses, show_layouts, &intercomputation_edges, hlo_execution_profile); - Appendf(&graph, "subgraph cluster_%s {\nlabel=\"%s\";\n%s}\n", + Appendf(&graph, + "subgraph cluster_%s " + "{\nstyle=rounded;label=<%s>;labelloc=t;\n%s}\n", ComputationId(embedded).c_str(), embedded->name().c_str(), graph_body.c_str()); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c9722b942b957e8cb9788d221a79910e9f4c6539..10ab60cc8449a59ef3aefcc12f67e4738d63b900 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -213,10 +213,10 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape)); if (window_util::HasBaseDilation(window)) { - instruction->set_name(instruction->name() + "-base-dilated"); + instruction->name_ = instruction->name() + "-base-dilated"; } if (window_util::HasWindowDilation(window)) { - instruction->set_name(instruction->name() + "-window-dilated"); + instruction->name_ = instruction->name() + "-window-dilated"; } instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); @@ -410,7 +410,9 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateReshape( const Shape& shape, HloInstruction* operand) { 
CHECK_EQ(ShapeUtil::ElementsIn(shape), - ShapeUtil::ElementsIn(operand->shape())); + ShapeUtil::ElementsIn(operand->shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(operand->shape()); auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; @@ -505,16 +507,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction* clone = nullptr; if (fused_instructions_computation_ == nullptr) { // New fusion instruction. - string computation_name; - HloModule* module = GetModule(); - if (module) { - computation_name = module->GetUniqueCompuationName( - instruction_to_fuse->name() + ".fusion"); - } else { - computation_name = instruction_to_fuse->name() + ".fusion"; - } - auto builder = HloComputation::Builder(computation_name, true); - builder.AddInstruction(instruction_to_fuse->Clone()); + auto builder = HloComputation::Builder("fused_computation", true); + builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); fused_instructions_computation_ = builder.Build(); clone = fused_expression_root(); clone->parent_fusion_instruction_ = this; @@ -522,7 +516,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( CHECK(fused_instructions_computation_ != nullptr && fused_instructions_computation_->IsFusionComputation()); clone = fused_instructions_computation_->AddInstruction( - instruction_to_fuse->Clone()); + instruction_to_fuse->Clone(/*suffix=*/"")); clone->parent_fusion_instruction_ = this; // instruction_to_fuse is necessarily an operand of the fusion instruction. // After fusion this will no longer be the case. Remove the operand from the @@ -578,8 +572,13 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // instruction. Add it as an operand and add a corresponding fused // parameter instruction. 
int64 param_no = fused_parameters_.size(); - std::unique_ptr param_instruction = CreateParameter( - param_no, operand->shape(), StrCat("fusion_param.", param_no)); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. Strip the leading "%" from the operand name + // to avoid a double %%. + string param_name = + StrCat(operand->name().substr(1), ".param_", param_no); + std::unique_ptr param_instruction = + CreateParameter(param_no, operand->shape(), param_name); param_instruction->parent_fusion_instruction_ = this; fused_param = fused_instructions_computation_->AddParameter( @@ -858,32 +857,36 @@ HloInstruction::~HloInstruction() {} std::unique_ptr HloInstruction::Clone(const string& suffix) { std::unique_ptr clone = CloneWithNewOperands(shape_, operands_); - // If an instruction is cloned multiple times avoid names like - // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric - // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the - // clone of foo.suffix2 is named foo.suffix3 and so on. - const string dot_suffix = "." + suffix; - size_t index = name().rfind(dot_suffix); - if (index == string::npos) { - // Existing name does not include ".suffix". - clone->name_ = name() + dot_suffix; + if (suffix.empty()) { + clone->name_ = name(); } else { - // Existing name includes ".suffix". Determine if substring after ".suffix" - // is numeric and should be replaced with an incremented number. - string after_suffix = name().substr(index + dot_suffix.size()); - if (after_suffix.empty()) { - // Existing name ends in ".suffix". New name should end in ".suffix2". - clone->name_ = name() + "2"; + // If an instruction is cloned multiple times avoid names like + // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric + // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the + // clone of foo.suffix2 is named foo.suffix3 and so on. 
+ const string dot_suffix = "." + suffix; + size_t index = name().rfind(dot_suffix); + if (index == string::npos) { + // Existing name does not include ".suffix". + clone->name_ = name() + dot_suffix; } else { - // If names ends with .suffix[0-9]+ then replace with a suffix with the - // numeric value incremented. - int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { - clone->name_ = - StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); + // Existing name includes ".suffix". Determine if substring after + // ".suffix" is numeric and should be replaced with an incremented number. + string after_suffix = name().substr(index + dot_suffix.size()); + if (after_suffix.empty()) { + // Existing name ends in ".suffix". New name should end in ".suffix2". + clone->name_ = name() + "2"; } else { - // Substring after ".suffix" is non-numeric. - clone->name_ = name() + dot_suffix; + // If names ends with .suffix[0-9]+ then replace with a suffix with the + // numeric value incremented. + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + clone->name_ = + StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); + } else { + // Substring after ".suffix" is non-numeric. + clone->name_ = name() + dot_suffix; + } } } } @@ -1080,7 +1083,7 @@ bool HloInstruction::Identical( // general, there is no need to check shape because shape is inferred from the // shape of the operands. 
if (opcode() != other.opcode() || - !ContainersEqual(operands(), other.operands(), eq_operands)) { + !ContainersEqual(operands(), other.operands(), std::move(eq_operands))) { return false; } @@ -1427,7 +1430,8 @@ string HloInstruction::ExtendedOpcodeStr() const { return opc_name; } -string HloInstruction::ToString(bool compact_operands) const { +string HloInstruction::ToString(bool compact_operands, + bool include_metadata) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. @@ -1508,8 +1512,9 @@ string HloInstruction::ToString(bool compact_operands) const { if (opcode() == HloOpcode::kGetTupleElement) { StrAppend(&extra, ", index=", tuple_index()); } - if (!metadata_.op_type().empty() || !metadata_.op_name().empty() || - !metadata_.source_file().empty()) { + if (include_metadata && + (!metadata_.op_type().empty() || !metadata_.op_name().empty() || + !metadata_.source_file().empty())) { StrAppend(&extra, " # metadata=", metadata_.ShortDebugString()); } @@ -1565,7 +1570,9 @@ string HloInstruction::ToCategory() const { return "non-elementwise fusion"; } case FusionKind::kInput: - return "reduce fusion"; + return "input fusion"; + case FusionKind::kOutput: + return "output fusion"; case FusionKind::kTransposeDot: return "dot fusion"; case FusionKind::kConvBackwardFilter: @@ -1613,7 +1620,6 @@ bool HloInstruction::IsFusable() const { // Some kinds of instructions don't make sense to fuse. 
switch (opcode_) { - case HloOpcode::kFusion: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kParameter: @@ -2181,6 +2187,8 @@ string ToString(HloInstruction::FusionKind kind) { return "kLoop"; case HloInstruction::FusionKind::kInput: return "kInput"; + case HloInstruction::FusionKind::kOutput: + return "kOutput"; case HloInstruction::FusionKind::kTransposeDot: return "kTransposeDot"; case HloInstruction::FusionKind::kConvBackwardFilter: @@ -2256,4 +2264,9 @@ HloModule* HloInstruction::GetModule() const { } return nullptr; } + +void HloInstruction::UniquifyName(NameUniquer* name_uniquer) { + name_ = name_uniquer->GetUniqueName(name_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3a03f28975e67880871c1e9f7d1d140e4b328c16..d300d99adec5201b70b0fe4eb65ef5b84362b018 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -53,7 +54,10 @@ class HloInstruction { public: enum class FusionKind { kLoop, // Fused into a loop. - kInput, // Fused into a reduction kernel. + kInput, // Op's input is fused into the op itself. + kOutput, // Op's output is fused into the op itself. + // REQUIRES: At least one operand buffer must be able + // to alias the output buffer. kTransposeDot, // Fused into a dot with transposed operands. kConvBackwardFilter, // Fused into a backward filter convolution. kConvBackwardInput, // Fused into a backward input convolution. 
@@ -488,7 +492,10 @@ class HloInstruction { string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false) const; + string ToString(bool compact_operands = false, + bool include_metadata = true) const; + + string ToStringNoMetadata() const { return ToString(false, false); } // As ToString, but returns a shorter string. string ToShortString() const; @@ -497,7 +504,9 @@ class HloInstruction { // or "elementwise". string ToCategory() const; - // Returns the string concatenation of parent name and this instructions name. + // Returns the string concatenation of parent name and this instructions + // name. This name is guaranteed to be unique among all instructions in the + // HloModule. string FullyQualifiedName() const; // Returns a logging instruction, if the output of this instruction is logged. @@ -721,8 +730,9 @@ class HloInstruction { // this instruction. const string& name() const { return name_; } - // Sets the string identifier for this instruction. - void set_name(const string& name) { name_ = name; } + // Use the given NameUniquer to select a unique name for the instruction based + // on the instruction's existing name. + void UniquifyName(NameUniquer* name_uniquer); // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index d0101ef19c34f27628a8a48607aad78f85e6d0f3..a226ab0d0c43e6df6216e4b0f58ed4270cb03d40 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -966,34 +966,39 @@ TEST_F(HloInstructionTest, CloneSuffixNames) { // Test that the suffix string added to cloned instructions is not // duplicated. Rather a numeric incrementing value should be appended. 
That // is, we want "foo.clone2", not "foo.clone.clone". - auto foo = HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.f)); - foo->set_name("foo"); // Test cloning the same instruction multiple times. - EXPECT_EQ(foo->Clone()->name(), "foo.clone"); - EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2"); - EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3"); + auto foo = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo"); + EXPECT_EQ(foo->Clone()->name(), "%foo.clone"); + EXPECT_EQ(foo->Clone()->Clone()->name(), "%foo.clone2"); + EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "%foo.clone3"); // Test custom suffixes. - EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar"); - EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2"); - EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone"); + EXPECT_EQ(foo->Clone("bar")->name(), "%foo.bar"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "%foo.bar2"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), + "%foo.bar2.clone"); // Test instruction name with a dot. - foo->set_name("foo.baz"); - EXPECT_EQ(foo->Clone()->name(), "foo.baz.clone"); + auto foo_baz = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.baz"); + EXPECT_EQ(foo_baz->Clone()->name(), "%foo.baz.clone"); // Test incrementing a large number after the suffix. - foo->set_name("foo.clone234"); - EXPECT_EQ(foo->Clone()->name(), "foo.clone235"); + auto foo_clone234 = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clone234"); + EXPECT_EQ(foo_clone234->Clone()->name(), "%foo.clone235"); // Test a non-numeric string after the cloning suffix. 
- foo->set_name("foo.clonexyz"); - EXPECT_EQ(foo->Clone()->name(), "foo.clonexyz.clone"); + auto foo_clonexyz = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz"); + EXPECT_EQ(foo_clonexyz->Clone()->name(), "%foo.clonexyz.clone"); // Test a name with multiple appearances of the suffix. - foo->set_name("foo.clone.clone3"); - EXPECT_EQ(foo->Clone()->name(), "foo.clone.clone4"); + auto foo_clone_clone3 = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3"); + EXPECT_EQ(foo_clone_clone3->Clone()->name(), "%foo.clone.clone4"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 8ed672aa9b8fb73cc120f55d93530b3124519fcb..f5e13b4367bed5b029862f76ce2dd9eeb2b42c49 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -46,8 +46,7 @@ HloModule::HloModule(const string& name) HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation) { - computation->set_name( - computation_name_uniquer_.GetUniqueName(computation->name())); + computation->UniquifyName(&computation_name_uniquer_); computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index b3168ed40ece3ea65c6b26b96250f2ea77969953..725ce17d6640fbbddbf11f4ca50c50c8c57e9bd3 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -34,15 +34,95 @@ limitations under the License. namespace xla { -PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) - : module_(module) {} +namespace { + +// Returns the nearest call graph ancestors of instructions 'a' and 'b' for +// which the ancestors are in the same computation. 
An instruction is an call +// graph ancestor of 'a' if the instruction calls the computation containing 'a' +// either directly or transitively. Degeneratively an instruction is an ancestor +// of itself. nullptr is returned if there is no common ancestor or if the +// caller chain of 'a' or 'b' diverges (has multiple callers) before the nearest +// common ancestor. +// +// Example: +// +// Entry computation: +// %x = Call(A, {Constant(42.0)}) +// %y = Call(B, {%x}) +// +// Computation A: +// %a = Negate(Param()) +// +// Computation B: +// %b = Exp(Param()); +// +// If called with %a and %b, this function would return (%x, %y). %x is an +// ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same +// computation. +std::pair +GetNearestCallGraphAncestorsInSameComputation(const HloInstruction* a, + const HloInstruction* b, + const CallGraph& call_graph) { + // Lambda which returns the next instruction in the callee->caller chain in + // the call graph. This is the unique instruction which calls the computation + // containing 'instruction'. If more than one instruction calls the + // computation containing 'instruction' or no instructions call the + // computation then nullptr is returned. + auto next_caller = + [&call_graph]( + const HloInstruction* instruction) -> const HloInstruction* { + const CallGraphNode& node = call_graph.GetNode(instruction->parent()); + if (node.caller_callsites().size() != 1) { + return nullptr; + } + return node.caller_callsites()[0].instruction(); + }; + + // Iterate through the callee->caller chains and find the earliest common + // element. 
+ for (const HloInstruction* a_ancestor = a; a_ancestor != nullptr; + a_ancestor = next_caller(a_ancestor)) { + for (const HloInstruction* b_ancestor = b; b_ancestor != nullptr; + b_ancestor = next_caller(b_ancestor)) { + if (a_ancestor->parent() == b_ancestor->parent()) { + return {a_ancestor, b_ancestor}; + } + } + } + return {nullptr, nullptr}; +} + +} // namespace -bool PredecessorHloOrdering::ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const { - // Instructions in different computations are unordered. - if (a->parent() != b->parent()) { +bool HloOrdering::ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const { + // 'a' and 'b' may be in different computations. In this case, find the + // callgraph ancestor instructions which call (potentially transitively) the + // computations containing 'a' and 'b' and use these ancestor instructions to + // compare order. + const HloInstruction* a_ancestor; + const HloInstruction* b_ancestor; + std::tie(a_ancestor, b_ancestor) = + GetNearestCallGraphAncestorsInSameComputation(a, b, *call_graph_); + + if (a_ancestor == nullptr) { + // Ancestors in a common computation could not be found so consider the + // instructions 'a' and 'b' to be unordered. return false; } + // a_ancestor and b_ancestor must be either both null or both non-null. + CHECK_NE(b_ancestor, nullptr); + CHECK_EQ(a_ancestor->parent(), b_ancestor->parent()); + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); +} + +PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) + : HloOrdering(module) {} + +bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const { + CHECK_EQ(a->parent(), b->parent()); + // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. 
return strict_predecessors_.at(b->parent())->IsReachable(b, a); } @@ -86,7 +166,7 @@ string DependencyHloOrdering::ToString() const { SequentialHloOrdering::SequentialHloOrdering( const HloModule* module, const HloModuleSequence& module_sequence) - : module_(module), module_sequence_(module_sequence) { + : HloOrdering(module), module_sequence_(module_sequence) { // Create a map from instruction to its order position. for (auto computation_order : module_sequence_) { const std::vector& order = computation_order.second; @@ -97,12 +177,9 @@ SequentialHloOrdering::SequentialHloOrdering( } } -bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const { - // Instructions in different computations are unordered. - if (a->parent() != b->parent()) { - return false; - } +bool SequentialHloOrdering::ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const { + CHECK_EQ(a->parent(), b->parent()); // If either instruction is not in the order, then 'a' and 'b' are unordered. if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { return false; @@ -144,23 +221,6 @@ string SequentialHloOrdering::ToString() const { return tensorflow::str_util::Join(pieces, "\n"); } -namespace { -StatusOr MinimumMemoryForSequence( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. 
- TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), sequence, - computation, points_to_analysis, size_function)); - return result.heap_size; -} -} // namespace - StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function) { @@ -172,17 +232,16 @@ StatusOr MinimumMemoryForSequence( TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(module)); - int64 total_memory = 0; - for (const auto& pair : module_sequence) { - const HloComputation* computation = pair.first; - const std::vector& sequence = pair.second; - TF_ASSIGN_OR_RETURN( - const int64 memory, - MinimumMemoryForSequence(*computation, sequence, *points_to_analysis, - size_function)); - total_memory += memory; - } - return total_memory; + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. 
+ TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; } namespace { @@ -439,6 +498,18 @@ StatusOr> RunDFSMemoryScheduler( return sequence; } +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, @@ -446,13 +517,17 @@ StatusOr> CreateMemoryMinimizingSequence( // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. + // + // Note that this is just a heuristic. One obvious inaccuracy is that the + // memory required for sub-computations might be different when considered + // within the caller's context. But it's good enough for now. 
TF_ASSIGN_OR_RETURN( std::vector list_sequence, ListScheduler::Run(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 list_memory, - MinimumMemoryForSequence(computation, list_sequence, points_to_analysis, - size_function)); + MinimumMemoryForComputation(computation, list_sequence, + points_to_analysis, size_function)); VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; TF_ASSIGN_OR_RETURN( @@ -460,8 +535,8 @@ StatusOr> CreateMemoryMinimizingSequence( RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, - MinimumMemoryForSequence(computation, dfs_sequence, points_to_analysis, - size_function)); + MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, + size_function)); VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; if (list_memory <= dfs_memory) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index e964c4c51ae14f89d1f1b0450990cfc50c8a74be..d2db18be0009b1ca62b538d3975e1a0a105c5e83 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -36,13 +37,13 @@ namespace xla { // buffers. class HloOrdering { public: - HloOrdering() = default; + HloOrdering(const HloModule* module) + : module_(module), call_graph_(CallGraph::Build(module)) {} virtual ~HloOrdering() = default; // Returns true if instruction 'a' executes before instruction 'b'. This is // not reflexive, that is, an instruction does not execute before itself. 
- virtual bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const = 0; + bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const; // Returns the sequential instruction order for the given computation, or // nullptr if the computation does not have a sequential ordering. @@ -50,6 +51,21 @@ class HloOrdering { const HloComputation& computation) const = 0; virtual string ToString() const = 0; + + protected: + // Returns true if instruction 'a' executes before instruction 'b'. + // Precondition: 'a' and 'b' are in the same computation. + // + // Derived classes should implement this method for determining order of + // instructions in the same comptuation. ExecutesBefore() analyzes the + // callgraph and uses this method to determine ordering of instructions in + // different computations. + virtual bool ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const = 0; + + const HloModule* module_; + + std::unique_ptr call_graph_; }; // Base class for partial orderings implemented by a map of strict predecessors @@ -58,11 +74,6 @@ class PredecessorHloOrdering : public HloOrdering { public: ~PredecessorHloOrdering() override = default; - // Returns true if instruction 'a' executes before instruction 'b'. - // Instructions in different computations are not ordered. - bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const override; - // Returns nullptr indicating the computation does not have a sequential // ordering. const std::vector* SequentialOrder( @@ -74,11 +85,12 @@ class PredecessorHloOrdering : public HloOrdering { explicit PredecessorHloOrdering(const HloModule* module); string ToStringHelper(const string& name) const; - const HloModule* module_; + bool ExecutesBeforeInSameComputation(const HloInstruction* a, + const HloInstruction* b) const override; - // For each each computation in the module, this is the set of the - // instruction's strict predecessors. 
An instruction is not an element of its - // own strict predecessor set. + // For each computation in the module, this is the set of the instruction's + // strict predecessors. An instruction is not an element of its own strict + // predecessor set. // // Subclasses should fill this in to define the desired ordering. tensorflow::gtl::FlatMap* SequentialOrder( const HloComputation& computation) const override; @@ -163,7 +169,9 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: - const HloModule* module_; + bool ExecutesBeforeInSameComputation(const HloInstruction* a, + const HloInstruction* b) const override; + const HloModuleSequence module_sequence_; // The position of every instruction in the HLO module in its respective diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 425bee601a8d6357e21d3d00f8ccf5d69af03862..c387fbb89b196c340852db057754f85e3e5435f3 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -78,6 +78,142 @@ TEST_F(HloOrderingTest, LastUseScheduledFirst) { EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); } +TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { + // Tests the ordering of instructions in different computations using the + // following HLO code: + // + // Entry computation: + // %x = Call(A, {}) + // %y = Call(B, {%x}) + // + // Computation A: + // %a = Call(C, {}) + // + // Computation B: + // %b = Call(C, {}) + // + // Computation C: + // %c = Constant(42.0f) + // + // This results in a diamond-shaped callgraph. 
+ HloModule module(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder_c = HloComputation::Builder("C"); + HloInstruction* c = builder_c.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloComputation* computation_c = + module.AddEmbeddedComputation(builder_c.Build()); + + auto builder_b = HloComputation::Builder("B"); + builder_b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* b = builder_b.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_c)); + HloComputation* computation_b = + module.AddEmbeddedComputation(builder_b.Build()); + + auto builder_a = HloComputation::Builder("A"); + HloInstruction* a = builder_a.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_c)); + HloComputation* computation_a = + module.AddEmbeddedComputation(builder_a.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* x = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_a)); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {x}, computation_b)); + module.AddEntryComputation(builder.Build()); + + DependencyHloOrdering ordering(&module); + EXPECT_TRUE(ordering.ExecutesBefore(x, y)); + EXPECT_FALSE(ordering.ExecutesBefore(y, x)); + + EXPECT_TRUE(ordering.ExecutesBefore(a, b)); + EXPECT_FALSE(ordering.ExecutesBefore(b, a)); + + EXPECT_FALSE(ordering.ExecutesBefore(a, x)); + EXPECT_TRUE(ordering.ExecutesBefore(a, y)); + EXPECT_FALSE(ordering.ExecutesBefore(x, a)); + EXPECT_FALSE(ordering.ExecutesBefore(y, a)); + + EXPECT_FALSE(ordering.ExecutesBefore(b, x)); + EXPECT_FALSE(ordering.ExecutesBefore(b, y)); + EXPECT_TRUE(ordering.ExecutesBefore(x, b)); + EXPECT_FALSE(ordering.ExecutesBefore(y, b)); + + // Instruction 'c' is called from multiple callsites and should be unordered + // relative to all other instructions in the 
module. + EXPECT_FALSE(ordering.ExecutesBefore(c, a)); + EXPECT_FALSE(ordering.ExecutesBefore(c, b)); + EXPECT_FALSE(ordering.ExecutesBefore(c, x)); + EXPECT_FALSE(ordering.ExecutesBefore(c, y)); + EXPECT_FALSE(ordering.ExecutesBefore(a, c)); + EXPECT_FALSE(ordering.ExecutesBefore(b, c)); + EXPECT_FALSE(ordering.ExecutesBefore(x, c)); + EXPECT_FALSE(ordering.ExecutesBefore(y, c)); +} + +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + HloModule module(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module.AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module.AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + 
HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, + MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index eb7fe467b32a330d9b8ad6000ad47849288b6b7e..78aebe9c36dfb5f63099f5e2df7bffe8529b08de 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -42,11 +42,17 @@ void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module, StatusOr HloPassPipeline::Run(HloModule* module) { run_called_ = true; + VLOG(1) << "Running HLO pass pipeline " << name(); + legacy_flags::HloPassPipelineFlags* flags = legacy_flags::GetHloPassPipelineFlags(); std::vector tmp = tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ','); tensorflow::gtl::FlatSet 
disabled_passes(tmp.begin(), tmp.end()); + if (!disabled_passes.empty()) { + VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " + << tensorflow::str_util::Join(disabled_passes, ", "); + } auto run_invariant_checkers = [this, module]() -> Status { for (auto& invariant_checker : invariant_checkers_) { @@ -62,9 +68,13 @@ StatusOr HloPassPipeline::Run(HloModule* module) { for (auto& pass : passes_) { if (!disabled_passes.empty() && disabled_passes.count(pass->name().ToString()) > 0) { + VLOG(1) << " Skipping HLO pass " << pass->name() + << ", disabled by --xla_disable_hlo_passes"; continue; } + VLOG(1) << " HLO pass " << pass->name(); + // Emit label containing: "after foo-pass, before bar-pass". message.clear(); StrAppend(&message, prefix, ", before ", pass->name()); diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index d6997378642cf480402b2edf8f40ed875fefa517..a153d73dbd838663c0d7e0d72ad54668f243f2c2 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -32,6 +32,16 @@ bool IsConstantR0F32(HloInstruction* instruction, float* out) { return false; } +bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) { + for (const auto& operand : instruction.operands()) { + if (operand->opcode() != HloOpcode::kParameter && + operand->opcode() != HloOpcode::kConstant) { + return false; + } + } + return true; +} + bool AllOperandsAreParameters(const HloInstruction& instruction) { for (const auto& operand : instruction.operands()) { if (operand->opcode() != HloOpcode::kParameter) { diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index 56f3cfd863ce0b9004d14e6c43d41f21b6e7a3bf..c79347bbf9d6146943b7b787f713369cb37fadee 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -28,6 +28,10 @@ namespace hlo_query { // Precondition: out != 
nullptr bool IsConstantR0F32(HloInstruction* instruction, float* out); +// Returns whether all of an instruction's operands are of the types constants +// and parameters. +bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction); + // Returns whether all of an instruction's operands are parameters. bool AllOperandsAreParameters(const HloInstruction& instruction); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 101c9076f8d1cb3e079ad665177751ccccfe65d9..5d4fd7c2deae7e1b03f49f123e2aff174ab34667 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -46,63 +46,58 @@ namespace xla { namespace { -// Returns a vector of the operands of 'instruction' with repeated elements -// removed. -std::vector UniqueOperands(const HloInstruction* instruction) { - std::vector unique_operands; - for (HloInstruction* operand : instruction->operands()) { - if (std::find(unique_operands.begin(), unique_operands.end(), operand) == - unique_operands.end()) { - unique_operands.push_back(operand); - } - } - return unique_operands; -} - // Returns true if the given instruction is rematerializable. bool IsRematerializable(const HloInstruction* instruction) { + // Conservatively, don't rematerialize instruction with control + // dependencies. 
For one, control dependencies are added to prevent + // interference of aliased buffers (say, in while bodies) and + // rematerialization is ignorant of liveness and may break the intended + // ordering. + if (!instruction->control_predecessors().empty() || + !instruction->control_successors().empty()) { + return false; + } + // Don't rematerialize instructions with side effects, those with a cost that // might not be captured by HloCostAnalysis, or instructions which cannot be // cloned safely. switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kOutfeed: case HloOpcode::kInfeed: + case HloOpcode::kParameter: case HloOpcode::kRecv: case HloOpcode::kSend: case HloOpcode::kTrace: case HloOpcode::kWhile: return false; default: - break; - } - - // Skip tuple shapes because we do not currently account for buffer aliasing - // properly which results in improperly accounting of rematerialization cost - // for these shapes. - if (ShapeUtil::IsTuple(instruction->shape())) { - return false; - } - for (auto* operand : instruction->operands()) { - if (ShapeUtil::IsTuple(operand->shape())) { - return false; - } + return true; } - - return true; } -// Class which maintains an ordered list of instructions with fast insertion and -// removal of arbitrary elements. +// Class which maintains an ordered list of instructions with fast insertion +// before arbitrary elements. class InstructionList { public: explicit InstructionList(const std::vector order) { + int64 position = 0; for (const HloInstruction* inst : order) { instructions_.push_back(const_cast(inst)); instruction_iterators_.insert({const_cast(inst), std::next(instructions_.end(), -1)}); + // Initially position numbers are uniquely assigned in order. 
Later as + // instructions are added with InsertBefore* methods, some instructions + // may have duplicate position numbers, but the values will be guaranteed + // to be monotonically increasing through the list, and so is still useful + // for quickly(-ish) determining the order of arbitrary instructions in + // the list. + position_number_[inst] = position; + first_at_position_[position] = inst; + position++; } } @@ -111,22 +106,63 @@ class InstructionList { return instructions_; } - // Insert instruction 'to_insert' before instruction 'before' in the list. - Status InsertBefore(HloInstruction* to_insert, HloInstruction* before) { + // Insert instruction 'to_insert' immediately before instruction 'before' in + // the list. + void InsertBefore(HloInstruction* to_insert, HloInstruction* before) { + VLOG(3) << "InsertBefore: " << to_insert->name() << " before " + << before->name(); auto it = instruction_iterators_.find(before); - TF_RET_CHECK(it != instruction_iterators_.end()); + CHECK(it != instruction_iterators_.end()); instruction_iterators_.insert( {to_insert, instructions_.insert(it->second, to_insert)}); - return Status::OK(); + // Assign the same position number to the newly added instruction as + // 'before'. This guarantees monotonicity of the position numbers, but not + // uniqueness. + int64 pos = position_number_.at(before); + position_number_[to_insert] = pos; + if (first_at_position_.at(pos) == before) { + first_at_position_[pos] = to_insert; + } } - // Removes instruction from the list. - Status Remove(HloInstruction* instruction) { - auto it = instruction_iterators_.find(instruction); - TF_RET_CHECK(it != instruction_iterators_.end()); - instructions_.erase(it->second); - instruction_iterators_.erase(it); - return Status::OK(); + // Insert instruction 'to_insert' immediately before the earliest instruction + // in 'before_instructions'. 
+ void InsertBeforeInstructions( + HloInstruction* to_insert, + tensorflow::gtl::ArraySlice before_instructions) { + VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {" + << tensorflow::str_util::Join( + before_instructions, ", ", + [](string* out, HloInstruction* inst) { + tensorflow::strings::StrAppend(out, inst->name()); + }) + << "}"; + + // Find the minimal position number of any instruction in + // 'before_instructions'. + CHECK(!before_instructions.empty()); + int64 min_position_number = std::numeric_limits::max(); + for (const HloInstruction* instruction : before_instructions) { + min_position_number = + std::min(min_position_number, position_number_.at(instruction)); + } + + // Because more than one instruction in 'before_instructions' may have a + // position number of 'min_position_number', find the first such instruction + // with position number 'min_position_number'. + for (auto it = instruction_iterators_.at( + first_at_position_.at(min_position_number)); + it != instructions_.end() && + position_number_.at(*it) == min_position_number; + ++it) { + if (std::find(before_instructions.begin(), before_instructions.end(), + *it) != before_instructions.end()) { + return InsertBefore(to_insert, *it); + } + } + LOG(FATAL) << "Expected to find instruction in before_instructions with " + "position number " + << min_position_number; } private: @@ -137,283 +173,630 @@ class InstructionList { tensorflow::gtl::FlatMap::iterator> instruction_iterators_; + + // A number assigned to each instruction which increases monotonically through + // 'instructions_'. Used to facilitate fast insertion of an instruction before + // the earliest instruction in a set of instructions + // (InsertBeforeInstructions) by enabling fast-ish ordering queries between + // instructions. If position_number_[a] < position_number_[b] then 'a' comes + // before 'b' in the list. 
If the position numbers are the same then nothing + // can be said about their order without examining the list. + // + // On object construction this value is precisely the instruction's ordinal + // position in the list. Instructions inserted via InsertBefore receive + // duplicate values. However, monotonicity is preserved. + tensorflow::gtl::FlatMap position_number_; + + // The first instruction in the list assigned a particular position number. + tensorflow::gtl::FlatMap first_at_position_; }; +// Return the HloInstructions which use the given LogicalBuffer. Sets +// has_indirect_users to whether any of the uses is indirect. A use is indirect +// if the instruction defining logical_buffer is not an operand of the use. This +// can happen via buffer aliasing (eg, tuples). +std::vector GetUsers( + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) { + std::vector users; + // To identify uses iterate through all HloInstruction users of the + // BufferAliases of the logical buffer. + *has_indirect_users = false; + for (const BufferAlias& buffer_alias : + points_to_analysis.GetBufferAliases(*logical_buffer)) { + for (const HloInstruction* user : buffer_alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(buffer_alias.instruction(), + buffer_alias.index(), user, + points_to_analysis)) { + // The alias may be an operand of 'user', but the LogicalBuffer cannot + // possibly be used by the instruction so ignore 'user'. This is the + // case, for example, for the tuple element buffers in a GetTupleElement + // instruction (the GTE instruction only uses the pointer vector). + continue; + } + if (buffer_alias.instruction() != logical_buffer->instruction()) { + *has_indirect_users = true; + } + // A buffer may be used by the instruction via more than one alias. For + // example, a buffer which appears in more than one element of a tuple. 
+ if (std::find(users.begin(), users.end(), user) == users.end()) { + users.push_back(user); + } + } + } + return users; +} + // Class for tracking memory usage of a computation as the instructions are -// placed sequentially. Memory usage is the sum of live values at the current -// point in the instruction sequence. +// placed sequentially. Memory usage is the sum of the sizes of live values +// (LogicalBuffers) at the current point in the instruction sequence. class MemoryUsageTracker { public: MemoryUsageTracker( const HloComputation* computation, - const HloRematerialization::ShapeSizeFunction& size_function) - : computation_(computation), size_function_(size_function) { - for (const std::unique_ptr& instruction : - computation->instructions()) { - // Initially only live-in values occupy memory. - if (IsLiveIn(instruction.get())) { - memory_usage_ += TotalSizeBytes(instruction->shape()); - } + const HloRematerialization::ShapeSizeFunction& size_function, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list); + + // Starts the placement of the given instruction. This adds the sizes of the + // LogicalBuffers defined by the instruction to the current memory + // usage. Placement is broken into two steps (BeginInstruction and + // EndInstruction) to accurately model memory usage. At BeginInstruction the + // memory for the output value(s) of the current instruction is allocated. At + // EndInstruction memory for dead operand(s) is freed. + Status BeginInstruction(const HloInstruction* instruction); + + // Finishes the placement of the current instruction. This frees any dead + // operands or dead result of the instruction. This must be called after + // each call to BeginInstruction. + Status EndInstruction(); + + // Returns the number of bytes that the current memory usage will be reduced + // if the given instruction is rematerialized. 
+ int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const; + + // Adjusts memory usage to account for the rematerialization of + // original_instruction for all remaining unplaced uses. The rematerialization + // is remat_instruction. This method should be called after the HLO graph has + // been transformed (rematerialization instruction created and connected to + // uses). + Status AddRematerializedInstruction(HloInstruction* original_instruction, + HloInstruction* remat_instruction); + + // Returns whether the given instruction has been placed (BeginInstruction + // has been called with 'instruction' as the argument). + bool IsPlaced(const HloInstruction* instruction) const { + return ContainsKey(placed_instructions_, instruction); + } + + // Returns the current memory usage. This is the sum of sizes of all live + // values. + int64 memory_usage() const { return memory_usage_; } + + // Returns the current instruction being placed. + const HloInstruction* in_progress_instruction() const { + return in_progress_instruction_; + } + + // Check invariants of the data structure. This is expensive to call. + bool Check() const; + + string ToString() const; + + private: + // Type holding a unique identifier for each Buffer object. + using BufferId = int64; + + // A Buffer represents a single LogicalBuffer in the computation including + // various metadata useful for tracking liveness of the value. A LogicalBuffer + // is not used directly because the HLO graph is transformed and + // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after + // HLO graph transformations. + struct Buffer { + // The unique id of this Buffer. This value is equal to the buffer's index + // in the vector buffers_. + const BufferId id; + + // The instruction which defines this buffer. + const HloInstruction* defining_instruction; + + // The materialized size of the buffer in bytes. 
+ const int64 size; + + // Whether this buffer is live-out of the computation. + bool live_out; + + // Whether this buffer has indirect uses. Ie, an instruction which is not a + // user of defining_instruction uses this buffer. This can occur due to + // buffer aliasing (eg, tuples). + bool has_indirect_uses; + + // The instructions which use this buffer. + std::vector users; + + // The number of users (HloInstructions) of this buffer which have not yet + // been placed in the sequence. + int64 unfinished_user_count; + + string ToString() const { + return tensorflow::strings::StrCat("Buffer ", id, " (defined by ", + defining_instruction->name(), + ", size ", size, " bytes)"); } + }; + + // Creates a Buffer representing the given logical buffer. The buffer is added + // to buffers_ and a reference is returned. + Buffer& CreateBufferFromLogicalBuffer( + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, + const HloRematerialization::ShapeSizeFunction& size_function, + bool live_out) { + bool has_indirect_uses = false; + std::vector users = + GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses); + return NewBuffer(logical_buffer->instruction(), + size_function(logical_buffer->shape()), std::move(users), + live_out, has_indirect_uses); } - // Starts the placement of the given instruction. This adds the output size of - // the instruction to the current memory usage. Placement is broken into two - // steps (BeginInstruction and EndInstruction) to accurately model memory - // usage. At BeginInstruction the memory for the output value of the current - // instruction is allocated. At EndInstruction memory for dead operands is - // freed. - Status BeginInstruction(const HloInstruction* instruction) { - VLOG(3) << "BeginInstruction " << instruction->name(); - TF_RET_CHECK(in_progress_instruction_ == nullptr); - in_progress_instruction_ = instruction; - - // Add instruction to remaining_uses_. 
- TF_RET_CHECK(!ContainsKey(remaining_uses_, instruction)); - std::vector& instruction_uses = - remaining_uses_[instruction]; - instruction_uses.insert(instruction_uses.begin(), - instruction->users().begin(), - instruction->users().end()); - - if (!IsLiveIn(instruction)) { - // Instruction was not previously live so add output size to memory usage. - memory_usage_ += TotalSizeBytes(instruction->shape()); + // Create a new buffer representing a rematerialization of given buffer for + // the given uses. + Buffer& RematerializeBuffer( + const Buffer& original_buffer, const HloInstruction* remat_instruction, + std::vector&& rematerialized_uses) { + CHECK(IsPlaced(original_buffer.defining_instruction)); + CHECK(!original_buffer.has_indirect_uses); + CHECK(!original_buffer.live_out); + for (const HloInstruction* use : rematerialized_uses) { + CHECK(!IsPlaced(use)); } + return NewBuffer(remat_instruction, original_buffer.size, + std::move(rematerialized_uses), /*live_out=*/false, + /*has_indirect_uses=*/false); + } + + // Return number of bytes allocated for the buffer with the given id. Buffers + // allocated by the calling computation (eg, parameter and output buffers) are + // considered to have zero bytes because the memory is accounted for in a + // different computation. + int64 AllocatedSize(BufferId buffer_id) const { + const Buffer& buffer = buffers_.at(buffer_id); + HloOpcode def_opcode = buffer.defining_instruction->opcode(); + if (buffer.live_out || def_opcode == HloOpcode::kParameter) { + return 0; + } else { + return buffer.size; + } + } - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + // Returns true if BeginInstruction and EndInstruction has been called for the + // given instruction. + bool IsFinished(const HloInstruction* instruction) const { + return IsPlaced(instruction) && instruction != in_progress_instruction_; } - // Finishes the placement of the current instruction. 
This frees any dead - // operands or dead result of the instruction. This must be called after each - // call to BeginInstruction. - Status EndInstruction() { - TF_RET_CHECK(in_progress_instruction_ != nullptr); - VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); - - for (HloInstruction* operand : UniqueOperands(in_progress_instruction_)) { - TF_RET_CHECK(ContainsKey(remaining_uses_, operand)); - std::vector& uses = remaining_uses_.at(operand); - auto it = std::find(uses.begin(), uses.end(), in_progress_instruction_); - TF_RET_CHECK(it != uses.end()); - uses.erase(it); - - if (uses.empty()) { - // Operand is dead. - int64 operand_size = TotalSizeBytes(operand->shape()); - if (!IsLiveOut(operand)) { - VLOG(4) << operand->name() << " (" - << HumanReadableNumBytes(operand_size) << ") is dead"; - memory_usage_ -= operand_size; - TF_RET_CHECK(memory_usage_ >= 0); + // Returns whether the given buffer is being used by the in-progress + // instruction. + bool IsInUse(BufferId buffer_id) const { + if (in_progress_instruction_ == nullptr) { + return false; + } + const std::vector& in_progress_uses = + buffers_used_by_instruction_.at(in_progress_instruction_); + return std::find(in_progress_uses.begin(), in_progress_uses.end(), + buffer_id) != in_progress_uses.end(); + } + + // Returns whether the given instruction is live at the current program + // point. + bool IsCurrentlyLive(BufferId buffer_id) const { + const Buffer& buffer = buffers_[buffer_id]; + return (IsPlaced(buffer.defining_instruction) && + buffer.unfinished_user_count > 0); + } + + // Create a new buffer, add it to buffers_, and return a reference. 
+ Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size, + std::vector&& users, bool live_out, + bool has_indirect_uses) { + int buffer_id = buffers_.size(); + buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out, + has_indirect_uses, users, + static_cast(users.size())}); + return buffers_.back(); + } + + const HloComputation* computation_; + + // Instruction list containing the ordering of instructions in + // computation_. This is the order in which instructions are placed + // (BeginInstruction/EndInstruction calls). + const InstructionList& instruction_list_; + + // Memory usage at the currently placed instruction. + int64 memory_usage_ = 0; + + // The instruction currently being placed. This value is non-null only + // between the calling of BeginInstruction and EndInstruction. + const HloInstruction* in_progress_instruction_ = nullptr; + + // The buffers defined by each instruction. + std::unordered_map> + buffers_defined_by_instruction_; + + // The buffers used by each instruction. + std::unordered_map> + buffers_used_by_instruction_; + + // The set of instructions which have been placed. That is, BeginInstruction + // has been called with the instruction as an argument. + tensorflow::gtl::FlatSet placed_instructions_; + + // All buffers in the computation. + std::vector buffers_; +}; + +MemoryUsageTracker::MemoryUsageTracker( + const HloComputation* computation, + const HloRematerialization::ShapeSizeFunction& size_function, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list) + : computation_(computation), instruction_list_(instruction_list) { + // Iterate through all LogicalBuffers in the computation and gather the + // instructions which define them in buffers_defined_by_instruction_ and the + // instructions which use them in buffers_used_by_instruction_. 
+ for (auto& instruction : computation_->instructions()) { + // Initialize empty vectors for defs and uses of each instruction. + buffers_used_by_instruction_[instruction.get()]; + buffers_defined_by_instruction_[instruction.get()]; + } + + tensorflow::gtl::FlatSet live_out_set = + points_to_analysis.GetPointsToSet(computation_->root_instruction()) + .CreateFlattenedSet(); + tensorflow::gtl::FlatMap + logical_buffer_to_buffer_id; + + for (const HloInstruction* instruction : instruction_list_.instructions()) { + for (const LogicalBuffer* logical_buffer : + points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { + Buffer* buffer; + if (instruction->opcode() == HloOpcode::kWhile) { + // The while instruction defines no new buffers. Instead it reuses the + // buffers of its operand. Find the Buffer of its operand at the + // proper ShapeIndex. + const PointsToSet& operand_points_to = + points_to_analysis.GetPointsToSet(instruction->operand(0)); + CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1); + const LogicalBuffer* source_logical_buffer = + operand_points_to.element(logical_buffer->index())[0]; + buffer = + &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer)); + + // Mark buffer as has indirect use and live out. + buffer->has_indirect_uses = true; + buffer->live_out = + buffer->live_out || ContainsKey(live_out_set, logical_buffer); + + // Add users of while to Buffer users. 
+ bool unused; + for (const HloInstruction* user : + GetUsers(logical_buffer, points_to_analysis, &unused)) { + if (std::find(buffer->users.begin(), buffer->users.end(), user) == + buffer->users.end()) { + buffer->users.push_back(user); + buffer->unfinished_user_count++; + buffers_used_by_instruction_.at(user).push_back(buffer->id); + } + } + } else { + buffer = &CreateBufferFromLogicalBuffer( + logical_buffer, points_to_analysis, size_function, + ContainsKey(live_out_set, logical_buffer)); + buffers_defined_by_instruction_.at(instruction).push_back(buffer->id); + for (const HloInstruction* user : buffer->users) { + buffers_used_by_instruction_.at(user).push_back(buffer->id); } } - } - // Value is dead if the instruction has no uses and is not live out. - if (in_progress_instruction_->users().empty() && - !IsLiveOut(in_progress_instruction_)) { - memory_usage_ -= TotalSizeBytes(in_progress_instruction_->shape()); - TF_RET_CHECK(memory_usage_ >= 0); + logical_buffer_to_buffer_id[logical_buffer] = buffer->id; } + } + XLA_VLOG_LINES(10, ToString()); + DCHECK(Check()); +} + +Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) { + VLOG(3) << "BeginInstruction " << instruction->name(); + TF_RET_CHECK(in_progress_instruction_ == nullptr); + in_progress_instruction_ = instruction; - in_progress_instruction_ = nullptr; + placed_instructions_.insert(in_progress_instruction_); - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + // All buffers defined by this instruction need memory. + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString() + << " is now live."; + memory_usage_ += AllocatedSize(buffer_id); } - // Adjusts memory usage to account for the rematerialization of - // original_instruction for the given use. The rematerialization is - // remat_instruction. 
This method should be called after the HLO graph has - // been transformed (rematerialization instruction created and connected to - // its use). - Status RematerializeInstructionForUse(HloInstruction* original_instruction, - HloInstruction* remat_instruction, - HloInstruction* use) { - VLOG(3) << "RematerializeInstructionForUse: original_instruction = " - << original_instruction->name() - << ", remat_instruction = " << remat_instruction->name() - << ", use = " << use->name(); - - TF_RET_CHECK(in_progress_instruction_ != nullptr); - TF_RET_CHECK(IsPlaced(original_instruction)); - TF_RET_CHECK(!IsPlaced(remat_instruction)); - TF_RET_CHECK(!IsPlaced(use)); - TF_RET_CHECK(IsCurrentlyLive(original_instruction)); - - // Remove 'use' from remaining uses of original_instruction. - auto it = std::find(remaining_uses_[original_instruction].begin(), - remaining_uses_[original_instruction].end(), use); - TF_RET_CHECK(it != remaining_uses_[original_instruction].end()); - remaining_uses_[original_instruction].erase(it); - - // If original_instruction is no longer live ('use' was its last use) then - // deduct original_instruction's memory usage. - if (!IsCurrentlyLive(original_instruction)) { - memory_usage_ -= TotalSizeBytes(original_instruction->shape()); - TF_RET_CHECK(memory_usage_ >= 0); + // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead) + // operand. Account for this potential reuse here. 
+ + VLOG(3) << " memory usage = " << memory_usage_; + VLOG(10) << ToString(); + + DCHECK(Check()); + return Status::OK(); +} + +Status MemoryUsageTracker::EndInstruction() { + TF_RET_CHECK(in_progress_instruction_ != nullptr); + VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); + + for (BufferId buffer_id : + buffers_used_by_instruction_.at(in_progress_instruction_)) { + Buffer& buffer = buffers_.at(buffer_id); + buffer.unfinished_user_count--; + CHECK_GE(buffer.unfinished_user_count, 0) + << buffer.ToString() << " has negative unfinished use count."; + if (buffer.unfinished_user_count == 0) { + // Buffer is now dead. + VLOG(3) << " " << buffer.ToString() << " is now dead."; + memory_usage_ -= AllocatedSize(buffer_id); + CHECK_GE(memory_usage_, 0); } + } - // Add the new remat_instruction to the remaining uses of its operands. - for (auto* operand : UniqueOperands(remat_instruction)) { - // Rematerialization may extend the lifetime of the operand so account for - // this in memory_usage_. - TF_RET_CHECK(IsPlaced(operand)); - if (!IsCurrentlyLive(operand)) { - memory_usage_ += TotalSizeBytes(operand->shape()); - } - remaining_uses_.at(operand).push_back(remat_instruction); + // If any buffer defined by this instruction has no uses, then memory can be + // reclaimed immediately. 
+ for (BufferId buffer_id : + buffers_defined_by_instruction_.at(in_progress_instruction_)) { + const Buffer& buffer = buffers_.at(buffer_id); + if (buffer.unfinished_user_count == 0) { + VLOG(3) << " " << buffer.ToString() << " is immediately dead."; + memory_usage_ -= AllocatedSize(buffer_id); + CHECK_GE(memory_usage_, 0); } + } + + in_progress_instruction_ = nullptr; + + VLOG(3) << " memory usage = " << memory_usage_; + VLOG(10) << ToString(); + + DCHECK(Check()); - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + return Status::OK(); +} + +int64 MemoryUsageTracker::MemoryReducedIfRematerialized( + const HloInstruction* instruction) const { + CHECK_NE(in_progress_instruction_, nullptr); + if (!IsPlaced(instruction) || instruction == in_progress_instruction_) { + return 0; } - // Returns the number of bytes that the current memory usage will be reduced - // if the given instruction is rematerialized. - int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const { - // To reduce memory consumption 'instruction' must be currently live and - // rematerialization must make 'instruction' not live. - if (IsLiveIn(instruction) || IsLiveOut(instruction) || - !IsCurrentlyLive(instruction)) { + // TODO(b/37687140): Rematerialization can increase peak memory consumption at + // an earlier point in the program if rematerialization extends the live range + // of the operand of the instruction being rematerialized across the live + // range of the value of instruction being rematerialized. Don't rematerialize + // in this case (ie, return 0 here). + + // Compute the amount of memory reduced (if any) by rematerializing + // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer + // be live at this program point, so initially set memory_reduced to the + // size of its defined values. 
+ int64 memory_reduced = 0; + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + // Avoid rematerializing instructions with indirect uses as it is difficult + // to reason about liveness after rematerializing the instruction. + // TODO(b/37714814): Consider rematerialzing instructions with indirect + // uses. + if (buffers_.at(buffer_id).has_indirect_uses) { return 0; } - // If the in-progress instruction is a user of 'instruction' (or - // 'instruction' itself) then rematerializing 'instruction' cannot reduce - // memory usage because the value is required to be live at this program - // point. - if (in_progress_instruction_ == instruction || - in_progress_instruction_->IsUserOf(instruction)) { - return 0; + if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) { + memory_reduced += AllocatedSize(buffer_id); } + } - // Compute the amount of memory reduced (if any) by rematerializing - // 'instruction'. 'instruction' will no longer be live at this program - // point, so initially set memory_reduced to the size of its output value. - int64 memory_reduced = TotalSizeBytes(instruction->shape()); - - // Account for any operands whose live range must be extended across this - // program point. - for (const HloInstruction* operand : UniqueOperands(instruction)) { - if (!IsCurrentlyLive(operand)) { - // This operand of candidate is not live at this program - // point. Rematerializing 'instruction' will extend the operand's live - // range across this program point. - memory_reduced -= TotalSizeBytes(operand->shape()); - } + // Account for any logical buffers whose live range must be extended across + // this program point. + for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + if (!IsCurrentlyLive(buffer_id)) { + // This logical buffer is used by 'instruction' but is not live at this + // program point. Rematerializing 'instruction' will extend the buffer's + // live range across this program point. 
+ memory_reduced -= AllocatedSize(buffer_id); } - return memory_reduced; } - // Returns the remaining unplaced uses of the given instruction. - const std::vector& RemainingUses( - const HloInstruction* instruction) const { - return remaining_uses_.at(instruction); + return memory_reduced; +} + +Status MemoryUsageTracker::AddRematerializedInstruction( + HloInstruction* original_instruction, HloInstruction* remat_instruction) { + VLOG(3) << "AddRematerializedInstruction: original_instruction = " + << original_instruction->name() + << ", remat_instruction = " << remat_instruction->name(); + + TF_RET_CHECK(in_progress_instruction_ != nullptr); + TF_RET_CHECK(IsPlaced(original_instruction)); + TF_RET_CHECK(!IsPlaced(remat_instruction)); + CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction)); + CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction)); + + // Construct the list of buffers used and defined by the rematerialization. + buffers_defined_by_instruction_[remat_instruction]; + buffers_used_by_instruction_[remat_instruction] = + buffers_used_by_instruction_.at(original_instruction); + + // Account for the additional buffer uses created by the new rematerialization + // instruction. Update memory usage if the rematerialization makes a dead + // buffer live again. + for (BufferId buffer_id : + buffers_used_by_instruction_.at(original_instruction)) { + Buffer& buffer = buffers_.at(buffer_id); + if (buffer.unfinished_user_count == 0) { + // Buffer used by this instruction was dead, now is alive. + memory_usage_ += AllocatedSize(buffer.id); + } + + buffer.unfinished_user_count++; + buffer.users.push_back(remat_instruction); } - // Returns whether the given instruction has been placed (BeginInstruction has - // been called with 'instruction' as the argument). 
- bool IsPlaced(const HloInstruction* instruction) const { - return ContainsKey(remaining_uses_, instruction); - } - - // Returns whether the given instruction is live at the current program point. - bool IsCurrentlyLive(const HloInstruction* instruction) const { - return (!IsPlaced(instruction) && IsLiveIn(instruction)) || - (IsPlaced(instruction) && - (!RemainingUses(instruction).empty() || IsLiveOut(instruction))); - } - - string ToString() const { - string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", - computation_->name(), "\n"); - tensorflow::strings::StrAppend(&output, "memory usage = ", memory_usage(), - "\n"); - tensorflow::strings::StrAppend(&output, "Live values:\n"); - for (const auto& pair : remaining_uses_) { - const HloInstruction* instruction = pair.first; - const std::vector& uses = pair.second; - tensorflow::strings::StrAppend( - &output, " ", instruction->name(), "; remaining uses: ", - tensorflow::str_util::Join(uses, ", ", - [](string* out, HloInstruction* use) { - tensorflow::strings::StrAppend( - out, use->name()); - }), - "\n"); + // Create a new set of Buffers defined by the new rematerialization + // instruction. Update the internal data structures and memory use to account + // for them. + for (BufferId old_buffer_id : + buffers_defined_by_instruction_.at(original_instruction)) { + Buffer& old_buffer = buffers_.at(old_buffer_id); + + std::vector placed_users; + std::vector unplaced_users; + for (const HloInstruction* user : old_buffer.users) { + if (IsPlaced(user)) { + CHECK(IsFinished(user)); + placed_users.push_back(user); + } else { + unplaced_users.push_back(user); + } + } + old_buffer.users = std::move(placed_users); + old_buffer.unfinished_user_count = 0; + + // Buffer is now dead. 
+ memory_usage_ -= AllocatedSize(old_buffer.id); + + Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction, + std::move(unplaced_users)); + + buffers_defined_by_instruction_.at(remat_instruction) + .push_back(new_buffer.id); + for (const HloInstruction* user : new_buffer.users) { + std::vector& buffers_used = + buffers_used_by_instruction_.at(user); + std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, + new_buffer.id); } - return output; } - // Returns the current memory usage. This is the sum of sizes of all live - // values. - int64 memory_usage() const { return memory_usage_; } + VLOG(3) << " memory usage = " << memory_usage_; + XLA_VLOG_LINES(10, ToString()); - // Returns the current instruction being placed. - const HloInstruction* in_progress_instruction() const { - return in_progress_instruction_; - } + DCHECK(Check()); - private: - // Returns the total size of the shape (including nested elements) in bytes. - int64 TotalSizeBytes(const Shape& shape) const { - int64 total_size = 0; - ShapeUtil::ForEachSubshape( - shape, - [this, &total_size](const Shape& subshape, - const ShapeIndex& /*index*/) { - total_size += size_function_(subshape); - return Status::OK(); - }) - .IgnoreError(); - return total_size; - } - - // Returns true if the value of given instruction is live into the - // computation. - bool IsLiveIn(const HloInstruction* instruction) const { - return instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kParameter; - } - - // Returns true if the value of given instruction is live out of the - // computation. 
- bool IsLiveOut(const HloInstruction* instruction) const { - return instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kParameter || - instruction == instruction->parent()->root_instruction(); + return Status::OK(); +} + +string MemoryUsageTracker::ToString() const { + string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", + computation_->name(), "\n"); + tensorflow::strings::StrAppend( + &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); + for (const HloInstruction* instruction : instruction_list_.instructions()) { + string inprogress = + instruction == in_progress_instruction_ ? " in-progress" : ""; + string placed = IsPlaced(instruction) ? " placed" : ""; + tensorflow::strings::StrAppend(&output, " ", instruction->name(), + inprogress, placed, "\n Defines:\n"); + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + const Buffer& buffer = buffers_[buffer_id]; + string live = IsCurrentlyLive(buffer_id) ? " live" : ""; + tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, + ", ", buffer.unfinished_user_count, + " unfinished uses\n"); + } + tensorflow::strings::StrAppend(&output, " Uses:\n"); + for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + tensorflow::strings::StrAppend(&output, " ", + buffers_[buffer_id].ToString(), "\n"); + } } + return output; +} - const HloComputation* computation_; +bool MemoryUsageTracker::Check() const { + auto elements_are_unique = [](const std::vector& vec) { + return vec.size() == std::set(vec.begin(), vec.end()).size(); + }; + + // Verify buffers_defined_by_instruction_. 
+ for (auto& instruction : computation_->instructions()) { + const std::vector& defined_buffers = + buffers_defined_by_instruction_.at(instruction.get()); + CHECK(elements_are_unique(defined_buffers)) + << "Instruction " << instruction->name() + << " does not have unique defined buffers: " + << tensorflow::str_util::Join( + defined_buffers, ", ", [this](string* out, BufferId buffer_id) { + tensorflow::strings::StrAppend( + out, buffers_.at(buffer_id).ToString()); + }); - // Function which computes the size of the top-level buffer of a shape. - const HloRematerialization::ShapeSizeFunction size_function_; + for (const Buffer& buffer : buffers_) { + if (buffer.defining_instruction == instruction.get()) { + CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), + buffer.id) != defined_buffers.end()) + << "Instruction " << instruction->name() + << " defined buffers is missing: " << buffer.ToString(); + } + } + } - // Memory usage at the currently placed instruction. - int64 memory_usage_ = 0; + // Verify buffers_used_by_instruction_. 
+ for (auto& instruction : computation_->instructions()) { + const std::vector& used_buffers = + buffers_used_by_instruction_.at(instruction.get()); + CHECK(elements_are_unique(used_buffers)) + << "Instruction " << instruction->name() + << " does not have unique used buffers: " + << tensorflow::str_util::Join( + used_buffers, ", ", [this](string* out, BufferId buffer_id) { + tensorflow::strings::StrAppend( + out, buffers_.at(buffer_id).ToString()); + }); + } + for (const Buffer& buffer : buffers_) { + int64 unfinished_uses = 0; + for (const HloInstruction* user : buffer.users) { + const std::vector& used_buffers = + buffers_used_by_instruction_.at(user); + CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != + used_buffers.end()) + << "Instruction " << user->name() << " used buffers is missing " + << buffer.ToString(); + if (!IsFinished(user)) { + unfinished_uses++; + } + } + CHECK_EQ(buffer.unfinished_user_count, unfinished_uses) + << "Incorrect unplaced use count for " << buffer.ToString(); + } - // The instruction currently being placed. This value is non-null only between - // the calling of BeginInstruction and EndInstruction. - const HloInstruction* in_progress_instruction_ = nullptr; + // Verify live set size against memory_usage_. + int64 live_size = 0; + for (const Buffer& buffer : buffers_) { + // The while instruction reuses its input buffers as output buffers so + // don't double count its buffers if it is currently executing. + if (IsCurrentlyLive(buffer.id) && + !(buffer.defining_instruction == in_progress_instruction_ && + in_progress_instruction_->opcode() == HloOpcode::kWhile)) { + live_size += AllocatedSize(buffer.id); + } + } + CHECK_EQ(live_size, memory_usage_); - // remaining_uses is a vector of uses of the HLO instruction's value which - // have not yet been visited by in the rematerialization loop. Use to track - // liveness of HLO instructions. 
- // TODO(b/35212854): Track values using logical buffers rather than HLO - // instructions. Using HLO instructions over-estimates memory usage because - // buffer aliasing is ignored. - tensorflow::gtl::FlatMap> - remaining_uses_; -}; + return true; +} -// Computes and returns the cost of rematerializing the given instruction. Cost -// per rematerialized instruction is defined as: +// Computes and returns the cost of rematerializing the given instruction. +// Cost per rematerialized instruction is defined as: // // (flop_count + transcendental_count + element_count) / memory_reduced // @@ -425,33 +808,36 @@ class MemoryUsageTracker { // instruction. // // This is a rough estimate of the extra execution time per byte saved by -// rematerializing this instruction for its remaining uses. In general, we want -// the most memory saving for the least latency penalty which is captured by -// this heuristic. +// rematerializing this instruction for its remaining uses. In general, we +// want the most memory saving for the least latency penalty which is captured +// by this heuristic. int64 RematerializationCost(const HloInstruction* instruction, const MemoryUsageTracker& memory_tracker, const HloCostAnalysis& cost_analysis, int64 memory_reduced) { - const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); - const int64 elements_accessed = - bytes_accessed / - ShapeUtil::ByteSizeOfPrimitiveType(instruction->shape().element_type()); - - // A duplicate of the rematerialized instruction will be created at each - // remaining use. - int64 duplication = memory_tracker.RemainingUses(instruction).size(); - if (duplication == instruction->users().size()) { - // All remaining uses of instruction are after this point so we can remove - // the original instruciton after rematerialization. 
- duplication -= 1; + // If none of the users of 'instruction' have been placed in the sequence (as + // tracked by memory_tracker), then rematerialization of 'instruction' is a + // zero-cost move of 'instruction' in the sequence. + if (!std::any_of(instruction->users().begin(), instruction->users().end(), + [&memory_tracker](const HloInstruction* inst) { + return memory_tracker.IsPlaced(inst); + })) { + return 0; } + CHECK_GT(memory_reduced, 0); + const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); + const int64 elements_accessed = + ShapeUtil::IsTuple(instruction->shape()) + ? bytes_accessed + : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType( + instruction->shape().element_type()); // Multiply by 256 to improve precision of cost. Without this factor, // many instructions such as many elementwise instructions would have // zero cost because the bytes reduced can be several times greater than // the element count. - return 256 * duplication * + return 256 * (cost_analysis.flop_count(*instruction) + cost_analysis.transcendental_count(*instruction) + elements_accessed) / @@ -467,7 +853,7 @@ HloInstruction* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, const HloCostAnalysis& cost_analysis, - const tensorflow::gtl::FlatSet& remat_instructions) { + const tensorflow::gtl::FlatSet& blacklist) { HloInstruction* best = nullptr; int64 best_cost = 0; @@ -482,11 +868,11 @@ HloInstruction* PickRematerializationCandidate( } VLOG(5) << "considering rematerialization candidate " << candidate->name(); - if (ContainsKey(remat_instructions, candidate)) { - // Skip instructions which are rematerialization clones to avoid infinite - // loops of rematerializing the same instruction(s) repeatedly. + if (ContainsKey(blacklist, candidate)) { + // Skip instructions on the blacklist to avoid infinite loops of + // rematerializing the same instruction(s) repeatedly. 
VLOG(5) << "candidate " << candidate->name() - << " not viable: is a rematerialized instruction"; + << " is excluded from rematerialization"; continue; } @@ -525,7 +911,9 @@ HloInstruction* PickRematerializationCandidate( StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, const std::vector& order) const { - MemoryUsageTracker tracker(computation, size_function_); + InstructionList instruction_list(order); + MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, + instruction_list); int64 peak_memory = tracker.memory_usage(); for (const HloInstruction* instruction : order) { TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction)); @@ -542,9 +930,8 @@ StatusOr HloRematerialization::ComputePeakMemory( StatusOr HloRematerialization::CalledComputationsMemoryUsage( const HloInstruction* instruction) const { - TF_ASSIGN_OR_RETURN(const CallGraphNode* node, - call_graph_->GetNode(instruction->parent())); - const CallSite* callsite = node->GetCallSite(instruction); + const CallSite* callsite = + call_graph_->GetNode(instruction->parent()).GetCallSite(instruction); if (callsite == nullptr || callsite->context() == CallContext::kParallel) { return 0; } @@ -564,15 +951,24 @@ StatusOr HloRematerialization::RematerializeComputation( << " with limit " << HumanReadableNumBytes(memory_limit_bytes); VLOG(1) << "peak memory usage is " << HumanReadableNumBytes(computation_peak_memory_.at(computation)); + CHECK(!ContainsKey(rematerialized_computations_, computation)); InstructionList instruction_list(sequence->at(computation)); - MemoryUsageTracker memory_tracker(computation, size_function_); + MemoryUsageTracker memory_tracker(computation, size_function_, + *points_to_analysis_, instruction_list); bool changed = false; - // Set of instruction clones (not the originals) created during - // rematerialization. 
A record is kept to avoid rematerializing an instruction - // more than once to avoid looping infinitely during rematerialization. - tensorflow::gtl::FlatSet remat_instructions; + // To avoid an infinite loop rematerializing the same set of instructions ad + // infinitum, keep a blacklist of instructions which should not be + // rematerialized. + tensorflow::gtl::FlatSet blacklist; + + // If the rematerialization makes the source instruction dead, then the + // rematerialization is added to 'remat_move_instructions' (the + // rematerialization is essentially a move). If the next rematerialization of + // the instruction is also a move then the rematerialization is added to the + // blacklist. + tensorflow::gtl::FlatSet remat_move_instructions; // The peak memory of the computation at any point in the instruction // sequence. @@ -584,12 +980,12 @@ StatusOr HloRematerialization::RematerializeComputation( // instructions which are dead. int64 net_instructions_added = 0; - TF_ASSIGN_OR_RETURN(const CallGraphNode* call_graph_node, - call_graph_->GetNode(computation)); + const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); // Iterate through all instructions in the sequence. At each instruction // (program point) if memory_usage exceeds the specified limit then // rematerialize HLO instructions until memory_usage is reduced. 
+ int64 instruction_index = 0; for (auto list_it = instruction_list.instructions().begin(); list_it != instruction_list.instructions().end(); ++list_it) { HloInstruction* instruction = *list_it; @@ -599,7 +995,9 @@ StatusOr HloRematerialization::RematerializeComputation( VLOG(2) << "Program point at " << instruction->name() << ", memory usage = " << memory_tracker.memory_usage() - << ", callee usage = " << callee_usage; + << ", callee usage = " << callee_usage << ", [" << instruction_index + << "/" << instruction_list.instructions().size() << "]"; + instruction_index++; while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { VLOG(2) << "Over memory limit at instruction " << instruction->name() @@ -609,7 +1007,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); HloInstruction* best = PickRematerializationCandidate( - memory_tracker, instruction_list, cost_analysis_, remat_instructions); + memory_tracker, instruction_list, cost_analysis_, blacklist); if (best == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " @@ -620,44 +1018,42 @@ StatusOr HloRematerialization::RematerializeComputation( break; } - VLOG(1) << "Rematerializing instruction " << best->name(); + VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " + << memory_tracker.MemoryReducedIfRematerialized(best) << ")"; changed = true; remat_count++; - // Create a rematerialized copy of the candidate at each remaining use. - // Make a copy of remaining uses because RematerializeInstructionForUse - // modifies the remaining uses vector in memory_tracker. - // TODO(b/35213652): It may be profitable to share one rematerialized copy - // amongst more than one use. - std::vector remaining_uses_copy = - memory_tracker.RemainingUses(best); - for (HloInstruction* use : remaining_uses_copy) { - // Create a new rematerialized instruction in the HLO graph. 
- HloInstruction* remat = - computation->AddInstruction(best->Clone(/*suffix=*/"remat")); - - VLOG(3) << "Replacing use of " << best->name() << " in " << use->name() - << " with rematerialization " << remat->name(); + HloInstruction* remat = + computation->AddInstruction(best->Clone(/*suffix=*/"remat")); - TF_RETURN_IF_ERROR(best->ReplaceUseWith(use, remat)); - - // Account for the rematerialization in the memory tracker. - TF_RETURN_IF_ERROR( - memory_tracker.RematerializeInstructionForUse(best, remat, use)); - - // Insert rematerialized instruction right before its use. - TF_RETURN_IF_ERROR(instruction_list.InsertBefore(remat, use)); - - // Add rematerialized instruction to remat_instructions so the - // rematerialized instruction is not rematerialized again. - remat_instructions.insert(remat); - - net_instructions_added++; + // Replace each remaining use of 'best' with the rematerialization. + std::vector best_users_copy = best->users(); + for (HloInstruction* user : best_users_copy) { + if (!memory_tracker.IsPlaced(user)) { + VLOG(2) << " Replacing use of " << best->name() << " in " + << user->name() << " with " << remat->name(); + TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat)); + } } - // Original instruction should no longer be live at this point. All - // of its remaining uses are fed by rematerialized instructions. - TF_RET_CHECK(!memory_tracker.IsCurrentlyLive(best)); + // Account for the rematerialization in the memory tracker. + TF_RETURN_IF_ERROR( + memory_tracker.AddRematerializedInstruction(best, remat)); + + // Insert rematerialized instruction right before the earliest unplaced + // use of the instruction *and* the earliest unplaced last use of any + // operands of remat. Unplaced uses of the remat's operands are included + // because we don't want to extend the live range of remat's operands as + // this could increase memory usage. 
+ std::vector place_before = remat->users(); + for (auto* operand : remat->operands()) { + for (auto* operand_user : operand->users()) { + if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) { + place_before.push_back(operand_user); + } + } + } + instruction_list.InsertBeforeInstructions(remat, place_before); // If the rematerialized instruction is dead then rematerialization is // essentially a move. Don't delete the instruction now because we don't @@ -665,15 +1061,24 @@ StatusOr HloRematerialization::RematerializeComputation( // transformation because we keep maps with HloInstruction* values as // keys. if (best->users().empty()) { - VLOG(3) << best->name() << " is now dead"; - net_instructions_added--; + VLOG(2) << best->name() << " is now dead"; + if (ContainsKey(remat_move_instructions, best)) { + // Previously, 'best' was a rematerialization which killed the + // instruction it was a copy of. Now 'remat' is a rematerialization + // of 'best' and kills 'best'. Stop rematerializing this instruction + // to avoid an infinite loop. + blacklist.insert(remat); + } + remat_move_instructions.insert(remat); + } else { + net_instructions_added++; } VLOG(3) << "memory_usage after rematerialization = " << memory_tracker.memory_usage(); } - const CallSite* callsite = call_graph_node->GetCallSite(instruction); + const CallSite* callsite = call_graph_node.GetCallSite(instruction); if (callsite != nullptr && callsite->context() == CallContext::kSequential && memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { @@ -687,21 +1092,22 @@ StatusOr HloRematerialization::RematerializeComputation( // Recompute callee usage to account for any rematerialization performed // in the callee computations. - callee_usage = 0; for (HloComputation* called_computation : callsite->called_computations()) { - // Memory limit for the subcomputation is the memory limit less the - // amount of memory used at this point in the computation.
- int64 subcomputation_memory_limit_bytes = std::max( - 0, memory_limit_bytes - memory_tracker.memory_usage()); - TF_ASSIGN_OR_RETURN( - bool subcomputation_changed, - RematerializeComputation(called_computation, sequence, - subcomputation_memory_limit_bytes)); - changed |= subcomputation_changed; - - callee_usage += computation_peak_memory_.at(called_computation); + if (!ContainsKey(rematerialized_computations_, called_computation)) { + // Memory limit for the subcomputation is the memory limit less the + // amount of memory used at this point in the computation. + int64 subcomputation_memory_limit_bytes = std::max( + 0, memory_limit_bytes - memory_tracker.memory_usage()); + TF_ASSIGN_OR_RETURN( + bool subcomputation_changed, + RematerializeComputation(called_computation, sequence, + subcomputation_memory_limit_bytes)); + changed |= subcomputation_changed; + } } + TF_ASSIGN_OR_RETURN(callee_usage, + CalledComputationsMemoryUsage(instruction)); } peak_memory = std::max(peak_memory, @@ -711,37 +1117,33 @@ StatusOr HloRematerialization::RematerializeComputation( TF_RETURN_IF_ERROR(memory_tracker.EndInstruction()); } - if (peak_memory > memory_limit_bytes) { - LOG(WARNING) << "Can't reduce memory use of computation " - << computation->name() << " below " - << HumanReadableNumBytes(memory_limit_bytes) - << " by rematerialization (only reduced to " - << HumanReadableNumBytes(peak_memory) << ")"; - } - - // Verify that there are no more remaining uses. + // Verify some invariants on the memory tracker. 
+ CHECK_EQ(memory_tracker.memory_usage(), 0); for (auto& instruction : computation->instructions()) { - auto& remaining_uses = memory_tracker.RemainingUses(instruction.get()); - CHECK(remaining_uses.empty()) - << instruction->name() << " has remaining uses: " - << tensorflow::str_util::Join( - remaining_uses, ", ", [](string* out, HloInstruction* inst) { - tensorflow::strings::StrAppend(out, inst->name()); - }); + CHECK(memory_tracker.IsPlaced(instruction.get())); } - VLOG(1) << "Rematerialized " << remat_count << " instructions; " - << net_instructions_added << " net instructions added"; - VLOG(1) << "peak memory usage now " << HumanReadableNumBytes(peak_memory); + VLOG(1) << "In computation " << computation->name() << " rematerialized " + << remat_count << " instructions; " << net_instructions_added + << " net instructions added"; + VLOG(1) << " peak memory usage now " << HumanReadableNumBytes(peak_memory) + << " (was " + << HumanReadableNumBytes(computation_peak_memory_.at(computation)) + << ")"; // Update peak memory used by computation. - computation_peak_memory_[computation] = peak_memory; + computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. sequence->at(computation) .assign(instruction_list.instructions().begin(), instruction_list.instructions().end()); + rematerialized_computations_.insert(computation); + + instructions_rematerialized_ += remat_count; + net_instructions_added_ += net_instructions_added; + return changed; } @@ -754,6 +1156,28 @@ StatusOr HloRematerialization::Run( VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); + TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); + + // Adjust memory limit to account for the output of the entry + // computation. 
This is necessary because the per-computation accounting in + // MemoryUsageTracker does not include output as these are typically allocated + // by the caller. + int64 module_output_size = 0; + ShapeUtil::ForEachSubshape( + module->entry_computation()->root_instruction()->shape(), + [&module_output_size, this](const Shape& subshape, + const ShapeIndex& /*index*/) { + module_output_size += size_function_(subshape); + return Status::OK(); + }) + .IgnoreError(); + + const int64 adjusted_memory_limit_bytes = + memory_limit_bytes - module_output_size; + VLOG(1) << "Adjusted memory limit accounting for output (" + << HumanReadableNumBytes(module_output_size) + << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); + XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. TF_ASSIGN_OR_RETURN(*sequence, @@ -761,10 +1185,9 @@ StatusOr HloRematerialization::Run( *module, [this](const LogicalBuffer& buffer) { return size_function_(buffer.shape()); })); - // Compute peak memory usage of all computations in the module called in a // sequential context. - TF_ASSIGN_OR_RETURN(call_graph_, CallGraph::Build(module)); + call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( [this, sequence](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { @@ -776,9 +1199,15 @@ StatusOr HloRematerialization::Run( return Status::OK(); })); + // The peak memory usage of the module equals the peak memory use of the entry + // computation plus the output size of the computation. This is because the + // peak memory for a computation does not include the output as this is + // typically accounted for in the caller.
+ const int64 before_peak_memory = + computation_peak_memory_.at(module->entry_computation()) + + module_output_size; VLOG(1) << "Peak memory usage of module (before): " - << HumanReadableNumBytes( - computation_peak_memory_[module->entry_computation()]); + << HumanReadableNumBytes(before_peak_memory); // Run cost analysis. Operation cost is used in the heuristic for selecting // instructions for rematerialization. @@ -787,9 +1216,9 @@ StatusOr HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, - RematerializeComputation(module->entry_computation(), - sequence, memory_limit_bytes)); + TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( + module->entry_computation(), sequence, + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -824,19 +1253,38 @@ StatusOr HloRematerialization::Run( computation->instruction_count()); } } - - VLOG(1) << "Peak memory usage of module (after): " - << HumanReadableNumBytes( - computation_peak_memory_[module->entry_computation()]); + VLOG(1) << "Rematerialized " << instructions_rematerialized_ + << " instructions in module " << module->name() << "; " + << net_instructions_added_ << " net instructions added"; + const int64 current_peak_memory = + computation_peak_memory_.at(module->entry_computation()) + + module_output_size; + VLOG(1) << "Peak memory usage of module now " + << HumanReadableNumBytes(current_peak_memory) << " (" + << current_peak_memory << " bytes), was " + << HumanReadableNumBytes(before_peak_memory) << " (" + << before_peak_memory << " bytes)"; + const int64 reduced_peak_memory = before_peak_memory - current_peak_memory; + VLOG(1) << "Reduced peak memory by " + << HumanReadableNumBytes(reduced_peak_memory) << " (" + << reduced_peak_memory << " bytes)"; XLA_VLOG_LINES(3, "After 
HloRematerialization:\n" + module->ToString()); + if (current_peak_memory > memory_limit_bytes) { + LOG(WARNING) << "Can't reduce memory use below " + << HumanReadableNumBytes(memory_limit_bytes) + << " by rematerialization (only reduced to " + << HumanReadableNumBytes(current_peak_memory) << ")"; + } + return changed; } /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence) { + const HloRematerialization::ShapeSizeFunction& size_function, + int64 memory_limit_bytes, HloModule* hlo_module, + SequentialHloOrdering::HloModuleSequence* sequence) { HloRematerialization remat(size_function); return remat.Run(hlo_module, sequence, memory_limit_bytes); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 86e1998b89454f75b1c10d0de2118fd1034c134d..1693f93183bc59c343e3c765cb4051566d4377ef 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -21,6 +21,7 @@ #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { @@ -108,6 +109,23 @@ class HloRematerialization { // occurs. tensorflow::gtl::FlatMap computation_peak_memory_; + + std::unique_ptr points_to_analysis_; + + // Set of computations which have had rematerialization + // applied. Rematerialization is only applied once per computation. + tensorflow::gtl::FlatSet rematerialized_computations_; + + // Count of the total instructions rematerialized. + int64 instructions_rematerialized_ = 0; + + // Count of the net instructions added to the HLO module by + // rematerialization. 
This can be different than instructions_rematerialized_ + // because some rematerializations are effectively moves in the HLO + // schedule. In these cases, the rematerialization instruction replaces all + // uses of the original instruction and the original instruction is + // dead. Hence, no net instructions were added. + int64 net_instructions_added_ = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 0a4f2776891cfc932b4fc0627daaa9b5408f420a..2a1d728bc84067e6ad7f1f622216ab39b2b474d3 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -30,12 +31,16 @@ limitations under the License. namespace xla { namespace { -class HloOrderingTest : public HloTestBase { +namespace op = xla::testing::opcode_matchers; + +using ::testing::_; + +class HloRematerializationTest : public HloTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. 
The computation looks like: // - // F32[1] %param = {...} + // F32[] %param = {...} // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -52,7 +57,7 @@ class HloOrderingTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); + HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); auto negate = builder.AddInstruction( @@ -77,7 +82,7 @@ class HloOrderingTest : public HloTestBase { // Creates and returns a computation which includes a while and can benefit // from rematerialization. The computation looks like: // - // F32[1] %param = {...} + // F32[] %param = {...} // F32[1024] %bcast = broadcast(%param) // F32[1] %slice_1 = slice(%bcast, {0:1}) // F32[1] %while = while(%slice_1, while_body, while_cond) @@ -93,7 +98,7 @@ class HloOrderingTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); + HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); auto slice_1 = builder.AddInstruction( @@ -127,13 +132,14 @@ class HloOrderingTest : public HloTestBase { } // Various shapes used in the canned computations. + const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024}); }; // Test rematerialization of a single computation produced by // MakeRematerializableComputation. 
-TEST_F(HloOrderingTest, SingleComputation) { +TEST_F(HloRematerializationTest, SingleComputation) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeRematerializableComputation()); @@ -141,11 +147,9 @@ TEST_F(HloOrderingTest, SingleComputation) { // Find and save the original broadcast instruction which should be // rematerialized. const HloInstruction* slice = computation->root_instruction(); - ASSERT_EQ(HloOpcode::kSlice, slice->opcode()); + ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _))); const HloInstruction* concat = slice->operand(0); - ASSERT_EQ(HloOpcode::kConcatenate, concat->opcode()); const HloInstruction* bcast = concat->operand(0); - ASSERT_EQ(HloOpcode::kBroadcast, bcast->opcode()); SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB @@ -161,8 +165,7 @@ TEST_F(HloOrderingTest, SingleComputation) { // The broadcast should have been rematerialized. const HloInstruction* remat_bcast = concat->operand(0); - EXPECT_EQ(HloOpcode::kBroadcast, remat_bcast->opcode()); - EXPECT_NE(bcast, remat_bcast); + EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast))); // The rematerialized broadcast should be immediate before the concat in the // sequence. @@ -175,7 +178,7 @@ TEST_F(HloOrderingTest, SingleComputation) { // Test rematerialization of a single computation produced by // MakeRematerializableComputation but with a sufficiently high memory limit // such that no instructions are rematerialized. -TEST_F(HloOrderingTest, SingleComputationNoRematerialization) { +TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeRematerializableComputation()); @@ -199,7 +202,7 @@ TEST_F(HloOrderingTest, SingleComputationNoRematerialization) { // only one computation needs to have an instruction rematerialized. 
The entry // computation should be the one chosen because rematerialization in the while // will presumably be more expensive. -TEST_F(HloOrderingTest, RematerializeAroundWhile) { +TEST_F(HloRematerializationTest, RematerializeAroundWhile) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -237,7 +240,7 @@ TEST_F(HloOrderingTest, RematerializeAroundWhile) { // Test rematerialization of a computation which calls another computation via a // while. Both the entry computation and while body computation should have // computations rematerialized. -TEST_F(HloOrderingTest, RematerializeEntryAndWhileBody) { +TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -271,7 +274,7 @@ TEST_F(HloOrderingTest, RematerializeEntryAndWhileBody) { // Test rematerialization of a doubly nested computation. All computations // should have an instruction rematerialized. -TEST_F(HloOrderingTest, RematerializeNestedComputations) { +TEST_F(HloRematerializationTest, RematerializeNestedComputations) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -311,6 +314,203 @@ TEST_F(HloOrderingTest, RematerializeNestedComputations) { EXPECT_EQ(inner_computation->instruction_count(), 8); } +TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { + // Test that a single instruction is rematerialized several times. 
Module: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] %bcast = broadcast(%param) + // F32[1024] %add_1 = add(%bcast, bcast) + // F32[1024] %call_1 = call(Subcomputation, {%add_1}) + // F32[1024] %add_2 = add(%bcast, call_1) + // F32[1024] %call_2 = call(SubComputation, {%add_2}) + // F32[1024] %add_3 = add(%bcast, call_2) + // F32[1024] %call_3 = call(Subcomputation, {%add_3}) + // F32[1024] %add_4 = add(%bcast, call_3) + // + // Subcomputation: + // F32[1024] %param = {...} + // F32[2048] %concat = concat({%param, %param}) + // F32[1024] %slice = slice(%concat) + // + // The value %bcast is live across each call of Subcomputation (which requires + // 8KB) though the value is not used in the calls. Rematerializing %bcast + // across these calls reduces peak memory use from ~20KB down to ~16KB. + HloModule module(TestName()); + + HloComputation* subcomputation = nullptr; + { + auto builder = HloComputation::Builder(TestName() + ".subcomputation"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {2048}), {param, param}, + /*dimension=*/0)); + builder.AddInstruction(HloInstruction::CreateSlice( + vec1024_shape_, concat, /*start_indices=*/{0}, + /*limit_indices=*/{1024})); + subcomputation = module.AddEmbeddedComputation(builder.Build()); + } + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); + auto call_1 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation)); + auto add_2 = 
builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); + auto call_2 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation)); + auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_2)); + auto call_3 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation)); + auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_3)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + auto count_broadcasts = [](const HloComputation* computation) { + int64 bcast_count = 0; + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBroadcast) { + bcast_count++; + } + } + return bcast_count; + }; + + // Before rematerialization there should be a single broadcast instruction in + // the graph. + EXPECT_EQ(count_broadcasts(entry_computation), 1); + EXPECT_EQ(entry_computation->instruction_count(), 9); + + EXPECT_EQ(add_2->operand(0), bcast); + EXPECT_EQ(add_3->operand(0), bcast); + EXPECT_EQ(add_4->operand(0), bcast); + + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSIGN_OR_ASSERT_OK( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, &module, &sequence)); + EXPECT_TRUE(changed); + + // The broadcast should have been rematerialized 3 times. + EXPECT_EQ(count_broadcasts(entry_computation), 4); + EXPECT_EQ(entry_computation->instruction_count(), 12); + + // The operands of add_2, add_3, and add_4 should all be rematerialized + // broadcasts. 
+ EXPECT_NE(add_2->operand(0), bcast); + EXPECT_THAT(add_2->operand(0), op::Broadcast(param)); + EXPECT_NE(add_3->operand(0), bcast); + EXPECT_THAT(add_3->operand(0), op::Broadcast(param)); + EXPECT_NE(add_4->operand(0), bcast); + EXPECT_THAT(add_4->operand(0), op::Broadcast(param)); +} + +class IndirectUseTest : public HloRematerializationTest, + public ::testing::WithParamInterface {}; + +TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { + // Test that an rematerializable instruction is not rematerialized if it has + // an indirect use. Test is parameterized on whether the value has an indirect + // use, and the instruction should be rematerialized iff the value has no + // indirect use. Module: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] %bcast = broadcast(%param) + // F32[1024] %add_1 = add(%bcast, bcast) + // F32[1024] %call = call(Subcomputation, {%add_1}) + // F32[1024] %add_2 = add(%bcast, call) + // {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2) + // F32[1024] %gte = GetTupleElememt(%tuple, 0) + // F32[1024] %negate = negate(%gte) + // + // Subcomputation: + // F32[1024] %param = {...} + // F32[2048] %concat = concat({%param, %param}) + // F32[1024] %slice = slice(%concat) + // + // The value %bcast is live across the call and rematerialization of %bcast + // across that point would reduce peak memory use by 4KB. However, %bcast is + // used indirectly in the %negate so rematerialization should not happen. + // + // This test is parameterized on whether the broadcast has an indirect use or + // not. The indirect use is controlled by the index of the GetTupleElement + // instruction. If the element is 0, then the %negate operand aliases %bcast + // (ie %bcast is used indirectly by %negate), otherwise the %negate operand + // aliases %add_2. 
+ const bool indirectly_used = GetParam(); + HloModule module(TestName()); + + HloComputation* subcomputation = nullptr; + { + auto builder = HloComputation::Builder(TestName() + ".subcomputation"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {2048}), {param, param}, + /*dimension=*/0)); + builder.AddInstruction(HloInstruction::CreateSlice( + vec1024_shape_, concat, /*start_indices=*/{0}, + /*limit_indices=*/{1024})); + subcomputation = module.AddEmbeddedComputation(builder.Build()); + } + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); + auto call_1 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation)); + auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2})); + auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + vec1024_shape_, tuple, indirectly_used ? 0 : 1)); + builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(entry_computation->instruction_count(), 8); + + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). 
+ TF_ASSIGN_OR_ASSERT_OK( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, &module, &sequence)); + // Rematerialization should only occur if the rematerializable instruction has + // no indirect uses. + if (indirectly_used) { + EXPECT_FALSE(changed); + EXPECT_EQ(entry_computation->instruction_count(), 8); + } else { + EXPECT_TRUE(changed); + EXPECT_EQ(entry_computation->instruction_count(), 9); + } +} + +INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest, + ::testing::Values(true, false)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index fdc1c0ba2d78bed66ead05cf71177ddabbe80108..2b14eca5d1b36fbe8b863cb32d64c79fb56ce761 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -68,9 +68,8 @@ void CleanNodeName(string* name) { } Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { - LOG(INFO) << "Adding computation " << computation.name(); + VLOG(2) << "Adding computation " << computation.name(); for (auto embedded : computation.MakeEmbeddedComputationsList()) { - LOG(INFO) << "Adding embedded computation " << embedded->name(); for (auto& instruction : embedded->instructions()) { TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); } @@ -88,12 +87,18 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( if (ContainsKey(instruction_to_node_name_, instruction)) { return instruction_to_node_name_[instruction]; } + string node_name; // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. - string node_name = - instruction->IsFused() - ? 
GetNodeNameForInstruction(instruction->fusion_instruction()) - : instruction->parent()->name(); + if (instruction->IsFused()) { + node_name = GetNodeNameForInstruction(instruction->fusion_instruction()); + } else { + node_name = instruction->parent()->name(); + if (!instruction->metadata().op_name().empty()) { + // Always make computations contain TF ops but not the other way around. + StrAppend(&node_name, "/", instruction->metadata().op_name()); + } + } string instruction_name = instruction->name(); if (instruction->opcode() == HloOpcode::kParameter) { StrAppend(&instruction_name, ".", instruction->parameter_number()); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index df664080228e6e5a682aa1772e89f3380c898852..6041debc4ae0ccbaad99bec9a461b640aeffbccf 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -137,6 +137,28 @@ TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); } +TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) { + auto builder = HloComputation::Builder("GE"); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "param1")); + auto ge = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); + OpMetadata metadata; + metadata.set_op_name("x/y"); + metadata.set_op_type("Y"); + ge->set_metadata(metadata); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 3); + EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); + EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); + EXPECT_EQ(graph_def.node(2).input_size(), 2); + 
EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to"); + EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); +} + TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. auto negate_computation = CreateNegateComputation(); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 5e7bd4a7ce8a1152973979d4a8fdb790a7fbd219..6384f737b601000d5a9cc2386e5c896ca3a74b50 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -659,44 +659,6 @@ LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout) } } -namespace { - -// Given a pemutation of `{0, 1, ..., n}` `indices`, returns a permutation of -// `{0, 1, ..., n - to_delete.size() + to_insert.size()}` by deleting the -// indices `to_delete` wherever in `indices` they are, and inserting the indices -// `to_insert` arbitrarily at the back. 
-tensorflow::protobuf::RepeatedField -DeleteAndInsertIndices( - std::vector to_delete, std::vector to_insert, - tensorflow::protobuf::RepeatedField indices) { - std::sort(to_delete.begin(), to_delete.end(), std::greater()); - std::sort(to_insert.begin(), to_insert.end(), std::less()); - for (auto index : to_delete) { - auto i = indices.begin(); - while (i != indices.end()) { - if (*i == index) { - i = indices.erase(i); - } else { - if (*i > index) { - (*i)--; - } - ++i; - } - } - } - for (auto index : to_insert) { - for (auto i = indices.begin(); i != indices.end(); ++i) { - if (*i >= index) { - (*i)++; - } - } - indices.Add(index); - } - return indices; -} - -} // namespace - std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no) { @@ -705,7 +667,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape()) && ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && + if ((instruction->IsElementwiseOnOperand(operand_no) || + InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) && !ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape())) { @@ -719,21 +682,32 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( } if (instruction->opcode() == HloOpcode::kReshape) { - // Pick the operand layout that makes the reshape a bitcast. If the reshape - // only inserts or deletes degenerate dimensions, we can easily compute the - // desired layout by accordingly inserting and deleting the elements in the - // minor-to-major list. 
- bool merely_inserts_or_deletes_1_sized_dims; - std::vector inserted_indices, deleted_indices; - std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices, - inserted_indices) = - instruction->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); - if (merely_inserts_or_deletes_1_sized_dims) { - Layout operand_layout = LayoutUtil::MakeLayout( - AsInt64Slice(DeleteAndInsertIndices(inserted_indices, deleted_indices, - output_layout.minor_to_major()))); + // Prefer the operand layout that makes the reshape an bitcast. If any + // dimension bound is 1 in the operand shape, there may be several such + // layouts. So if 'output_layout' is a MajorToMinor layout, try if the + // reshape is a bitcast when using the same layout. This may avoid copy + // operations. + const Shape& output_shape = instruction->shape(); + Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), + AsInt64Slice(output_layout.minor_to_major())); + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsMonotonicWithDim0Major(output_layout)) { + Shape operand_shape_with_layout = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + operand_shape.element_type(), + AsInt64Slice(operand_shape.dimensions())); + if (ShapeUtil::ReshapeIsBitcast(operand_shape_with_layout, + output_shape_with_layout)) { + return MakeUnique(operand_shape_with_layout.layout()); + } + } + auto aligned_operand_shape = + ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); + if (aligned_operand_shape) { + auto operand_layout = aligned_operand_shape.value().layout(); TF_CHECK_OK( - LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); + LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); return MakeUnique(operand_layout); } } @@ -768,18 +742,32 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } if (user->opcode() == HloOpcode::kReshape) { - // Pick the user layout 
that makes the reshape a bitcast. - bool merely_inserts_or_deletes_1_sized_dims; - std::vector inserted_indices, deleted_indices; - std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices, - inserted_indices) = - user->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); - if (merely_inserts_or_deletes_1_sized_dims) { - Layout user_layout = LayoutUtil::MakeLayout(AsInt64Slice( - DeleteAndInsertIndices(deleted_indices, inserted_indices, - operand_layout.minor_to_major()))); + // Prefer the user layout that makes the reshape an bitcast. If any + // dimension bound is 1 in the user shape, there may be several such + // layouts. So if 'operand_layout' is a MajorToMinor layout, try if the + // reshape is a bitcast when using the same layout. This may avoid copy + // operations. + Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + operand->shape().element_type(), + AsInt64Slice(operand->shape().dimensions()), + AsInt64Slice(operand_layout.minor_to_major())); + const Shape& output_shape = user->shape(); + if (LayoutUtil::IsMonotonicWithDim0Major(operand_layout)) { + Shape output_shape_with_layout = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + output_shape.element_type(), + AsInt64Slice(output_shape.dimensions())); + if (ShapeUtil::ReshapeIsBitcast(output_shape_with_layout, + operand_shape_with_layout)) { + return MakeUnique(output_shape_with_layout.layout()); + } + } + auto aligned_user_shape = + ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); + if (aligned_user_shape) { + auto user_layout = aligned_user_shape.value().layout(); TF_CHECK_OK( - LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); + LayoutUtil::ValidateLayoutForShape(user_layout, output_shape)); return MakeUnique(user_layout); } } @@ -1040,7 +1028,7 @@ StatusOr InferArrayLayout( *first_buffer_layout)) { // The points-to set is ambiguous for this index and the different source // buffers have different layouts. 
This case is possible in valid XLA - // computations because we do not propagate BufferLayoutConstaints to all + // computations because we do not propagate BufferLayoutConstraints to all // LogicalBuffers which may alias the constrained LogicalBuffer at some // point in the computation. return FailedPrecondition( @@ -1253,7 +1241,7 @@ Status LayoutAssignment::RunOnComputation( TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(computation->parent())); - // Construct LayoutConstaints with all layout constraints of the computation. + // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(*points_to_analysis, computation); // Add constraints required for correctness on all backends (eg, entry diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 61dc7b120752d57cf09423f38546441de2fc8dd9..4f586c334dcdcb02cd7586750d39d6663c0f2703 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -248,6 +248,15 @@ class LayoutAssignment : public HloPassInterface { return Status::OK(); } + // This method can be overriden to mark instructions as requiring the operands + // to have the same layout as the result, for performance or correctness. This + // will propagate constraints through the instruction from the result into the + // operands. + virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction) { + return false; + } + // Construct contraints and assign layouts to all instructions in the // computation satisfying the given ComputationLayout. 
Layouts constraints are // added, then propagated until all LogicalBuffers in the computation are diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index b6451738bdb4df8ce06efc2becd9f14aef92254d..c6df9839c33a86ee4d96ccece6ffdf4f496bc6fc 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -45,6 +45,8 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { +using ::testing::ElementsAre; + class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, @@ -317,7 +319,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { // param -> log -> reshape -> tanh auto builder = HloComputation::Builder(TestName()); Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1}); - Shape bshape = ShapeUtil::MakeShape(F32, {2, 1, 3}); + Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2}); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ashape, "param")); auto log = builder.AddInstruction( @@ -332,8 +334,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); - *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}); + *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3}); + *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_parameter_layout(0) = @@ -343,12 +345,12 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); - EXPECT_LT(PositionInContainer(log_minor_to_major, 1), + EXPECT_GT(PositionInContainer(log_minor_to_major, 1), 
PositionInContainer(log_minor_to_major, 2)); auto reshape_minor_to_major = AsInt64Slice(reshape->shape().layout().minor_to_major()); - EXPECT_LT(PositionInContainer(reshape_minor_to_major, 0), + EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0), PositionInContainer(reshape_minor_to_major, 2)); } @@ -421,8 +423,8 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(output_shape_with_layout); AssignLayouts(&module, &computation_layout); - EXPECT_TRUE( - ContainersEqual(broadcast->shape().layout().minor_to_major(), {0, 1, 2})); + EXPECT_THAT(broadcast->shape().layout().minor_to_major(), + ElementsAre(0, 1, 2)); } TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { @@ -474,11 +476,9 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { {transpose_shape_with_layout, broadcast2_shape_with_layout})); AssignLayouts(&module, &computation_layout); - EXPECT_TRUE( - ContainersEqual(broadcast->shape().layout().minor_to_major(), {0, 1})); - EXPECT_TRUE( - ContainersEqual(transpose->shape().layout().minor_to_major(), {1, 0})); - EXPECT_TRUE(ContainersEqual(tanh->shape().layout().minor_to_major(), {0, 1})); + EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); + EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); + EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1)); } // Add test which fails due to copy tuple. diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 6c5f185ed1ba544e4132777216b9594b5cad7904..16e11ca6c6b3c5ef4dea3cbab5ba6c284e716add 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -28,8 +28,9 @@ limitations under the License. 
namespace xla { -bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index, - HloInstruction* user, +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis) { CHECK(user->IsUserOf(operand)) << "user: " << user->ToString() << " operand: " << operand->ToString(); @@ -98,6 +99,41 @@ std::vector> GetAllUsesOfInstructionAtIndex( return uses; } +// Returns true if there is exactly one use of 'operand' at 'operand_index' +// in 'fusion.fused_instructions', where the singleton use is the fused +// root at operand index 'use_operand_index'. Returns false otherwise. +// +// REQUIRES: 'fusion' opcode is a kFusion instruction. +bool HasUniqueFusedUseOfOperandAt( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* fusion, const int64 use_operand_index, + const TuplePointsToAnalysis& points_to_analysis) { + CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); + // Check that 'operand' is unique in the operand list of 'fusion'. + if (fusion->OperandIndices(operand).size() > 1) { + return false; + } + // Find fusion parameter associated with 'operand'. + const auto& fused_params = fusion->fused_parameters(); + auto fused_param_it = std::find_if( + fused_params.begin(), fused_params.end(), + [&](HloInstruction* fused_param) { + return fusion->operand(fused_param->parameter_number()) == operand; + }); + if (fused_param_it == fused_params.end()) { + return false; + } + auto* fused_param = *fused_param_it; + // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. + auto fused_param_uses = GetAllUsesOfInstructionAtIndex( + fused_param, operand_index, points_to_analysis); + // Return true iff there is exactly one use of 'operand' at 'index', and + // this singleton use is the fused root (at index in 'use_operand_indices'). 
+ return fused_param_uses.size() == 1 && + fused_param_uses[0].first == fusion->fused_expression_root() && + fused_param_uses[0].second == use_operand_index; +} + } // namespace // User and operand can share buffers iff both instructions emit the same shape @@ -106,6 +142,9 @@ std::vector> GetAllUsesOfInstructionAtIndex( // *) Is a loop fusion instruction where the only use of 'operand' at 'index' // in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root // at operand 0. Or... +// *) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion +// instruction where the only use of 'operand' at 'index' in the set +// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... // *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0. bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, @@ -125,30 +164,46 @@ bool CanShareOperandBufferWithUser( if (user->opcode() == HloOpcode::kCopy) { return false; } - // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice - // fused root instruction. - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - for (auto& fused_param : user->fused_parameters()) { - // Find fusion parameter associated with 'operand'. - if (user->operand(fused_param->parameter_number()) != operand) { - continue; - } - // Get all uses of 'operand' at 'index' from 'user.fused_instructions'. - auto fused_param_uses = GetAllUsesOfInstructionAtIndex( - fused_param, operand_index, points_to_analysis); - // Return true iff there is exactly one use of 'operand' at 'index', and - // this singleton use is the fused root at operand index 0. 
- if (fused_param_uses.size() == 1 && - fused_param_uses[0].first == user->fused_expression_root() && - fused_param_uses[0].second == 0) { - return true; + if (user->opcode() == HloOpcode::kFusion) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, + points_to_analysis); + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is either kDot, or nested + // kFusion of kind kTransposeDot. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kDot || + (operand->opcode() == HloOpcode::kFusion && + operand->fusion_kind() == + HloInstruction::FusionKind::kTransposeDot); + }); + if (add_operand_it == add->operands().end()) { + return false; } - break; + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). 
+ return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, + other_add_operand_index, + points_to_analysis); } - return false; } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || user->opcode() == HloOpcode::kWhile) { diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h index 410a7b1b519e117f21c01938cb8e4a5b1c358ad2..52de282ca6b444867c865f845ce794196c98b277 100644 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ b/tensorflow/compiler/xla/service/liveness_util.h @@ -32,8 +32,9 @@ namespace xla { // 'operand'. Returns false otherwise. // // REQUIRES: 'operand' is an operand of 'user'. -bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index, - HloInstruction* user, +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis); // Returns true if 'user' (at 'user_index') can share a buffer with its operand diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 1ee02925117f846ee3ad41e151b125e57db22904..49c2c2d4a268d1237ae04903416cf1f6708609d3 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -34,9 +34,7 @@ class PointsToAnalysisTestBase : public HloTestBase { void RunAnalysis() { CHECK_NOTNULL(module_.get()); points_to_analysis_ = - TuplePointsToAnalysis::Run(module_.get(), - /*include_loop_fusion_instructions=*/true) - .ConsumeValueOrDie(); + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { @@ -231,6 +229,100 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); } +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = 
HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. + EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + auto b_t = builder.AddInstruction( + HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( 
+ data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + + auto nested_fusion = computation_->CreateFusionInstruction( + {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + + auto fusion = computation_->CreateFusionInstruction( + {add, nested_fusion}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused transpose-dot-add should be share buffer with 'add_operand'. + EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. 
+ EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, + *points_to_analysis_)); +} + TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 17d7b97b21bd3296711295e0779b0a273c9917e0..78d21233c765ec8f18a865f55b752d418ad126d6 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -60,9 +60,12 @@ namespace xla { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr backend, - Backend::CreateBackend(platform, options.number_of_replicas())); + BackendOptions backend_options; + backend_options.set_platform(platform) + .set_number_of_replicas(options.number_of_replicas()) + .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()); + TF_ASSIGN_OR_RETURN(std::unique_ptr backend, + Backend::CreateBackend(backend_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); @@ -77,21 +80,6 @@ LocalService::LocalService(std::unique_ptr execute_backend, runs_in_client_process_ = true; } -tensorflow::Status LocalService::ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs) { - TF_ASSIGN_OR_RETURN(std::vector arg_allocations, - ResolveAndValidateArguments( - arguments, execute_backend_.get(), device_ordinal)); - argument_ptrs->resize(arg_allocations.size()); - for (int i = 0; i < arguments.size(); ++i) { - const Allocation& allocation = *arg_allocations[i]; - (*argument_ptrs)[i] = allocation.device_memory(); - } - return tensorflow::Status::OK(); -} - namespace { // Returns the space required to allocate a shape. 
If // allocate_space_for_deep_copy the space includes all sub-buffers of @@ -128,70 +116,6 @@ StatusOr LocalService::AllocateBufferOnDevice( allocation_size)); } -StatusOr>> -LocalService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - std::vector> module_configs; - for (const AheadOfTimeComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - // Dump computation proto state if flag is set. - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - const string& directory_path = flags->xla_dump_computations_to; - if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, - /*include_unreachable_instructions=*/true)); - hlo_modules.push_back(std::move(hlo_module)); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - module_configs.push_back(MakeUnique(*program_shape)); - HloModuleConfig* module_config = module_configs.back().get(); - auto* computation_layout = - module_config->mutable_entry_computation_layout(); - if (flags->xla_hlo_profile) { - module_config->enable_hlo_profiling(true); - } - for (int i = 0; i < instance.argument_layouts.size(); ++i) { - const Shape& argument_layout = 
*instance.argument_layouts[i]; - if (ShapeUtil::IsTuple(argument_layout)) { - return Unimplemented("tuple arguments not supported yet"); - } - TF_RETURN_IF_ERROR( - computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( - argument_layout)); - } - TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( - *instance.result_layout)); - } - - return execute_backend_->compiler()->CompileAheadOfTime( - std::move(hlo_modules), std::move(module_configs), MakeHloDumper(), - options); -} - StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index df27f0a7a60dca99caf09994f417f1bc45ec15de..767a3ab697febb283af448b25369445152381a5e 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -43,14 +43,6 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // For an array of arguments, validate that each is placed on the - // specified device_ordinal, and return the DeviceMemoryBase - // corresponding to each argument. - tensorflow::Status ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs); - // Return a handle to a buffer large enough to hold shape, allocated // on device_ordinal. If allocate_space_for_deep_copy, the buffer is // large enough to hold all sub-buffers of a tuple shape, otherwise @@ -59,22 +51,6 @@ class LocalService : public Service { const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy); - // A description of a computation to compile using CompileAheadOfTime. 
- struct AheadOfTimeComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |LocalClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& Options); - // Builds an Executable with the given argument layouts and options. If // result_layout is non-null, then the executable is compiled to produce a // result of the given layout. diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index b72ef95a6a7964aa1f41cd2ceef4cdee76e9f708..768977ba6bba2f9af55fcd467aa3d91488e4bf0f 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -13,17 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/reshape_mover.h" - -#include -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { - -namespace { - +// Implementation note: +// // The general idea behind this pass is that we're converting from this: // %param.A = OldShape // %param.B = OldShape @@ -44,6 +35,19 @@ namespace { // only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or // transposes to a scalar should be cheap, we simply never move them. 
+#include "tensorflow/compiler/xla/service/reshape_mover.h" + +#include +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + // Finds the first non-scalar operand of an instruction that is a reshape or // transpose and returns the operand if it is found or nullptr if not found. HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) { @@ -51,6 +55,9 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) { if (!ShapeUtil::IsScalar(operand->shape()) && (operand->opcode() == HloOpcode::kReshape || operand->opcode() == HloOpcode::kTranspose)) { + VLOG(5) << "Found first non-scalar reshape operand of " + << hlo->ToStringNoMetadata() << ":\n\t" + << operand->ToStringNoMetadata(); return operand; } } @@ -70,6 +77,9 @@ bool OperandCanTrivallyChangeShape(const HloInstruction* instruction, // A constant can trivially reshape the literal it holds. if (operand->opcode() == HloOpcode::kConstant && ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Constant had same dimensions as instruction:\n\toperand: " + << operand->ToStringNoMetadata() + << "\n\tinstruction: " << instruction->ToStringNoMetadata(); return true; } @@ -116,119 +126,159 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes( if (!first_reshape_operand) { return false; } - return (instruction->user_count() > 0 || - instruction == instruction->parent()->root_instruction()) && - instruction->IsElementwise() && !operands.empty() && - // Check whether all operands: - // 1. are all reshapes or transposes that have the same input and - // output shapes as all other reshaped or transposed operands. - // or - // 2. can be any shape like kConstant, kRng, and scalars. 
- std::all_of( - operands.begin(), operands.end(), - [instruction, - first_reshape_operand](const HloInstruction* operand) { - return AreEquivalentReshapes(first_reshape_operand, operand) || - OperandCanTrivallyChangeShape(instruction, operand); - }); + VLOG(3) << "** Checking whether instruction is an elementwise operation of " + "equivalent reshapes/transposes: " + << instruction->ToStringNoMetadata(); + bool result = + (instruction->user_count() > 0 || + instruction == instruction->parent()->root_instruction()) && + instruction->IsElementwise() && !operands.empty() && + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, it may be + // implicitly broadcast, which can confound the movement's + // correctness. + // 1. Are all reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. Can be any shape like kConstant, kRng, and scalars. + std::all_of( + operands.begin(), operands.end(), + [instruction, first_reshape_operand](const HloInstruction* operand) { + if (!ShapeUtil::SameDimensions(operand->shape(), + instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToStringNoMetadata() << "\n\tinstruction: " + << instruction->ToStringNoMetadata(); + return false; + } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\toperand: " << operand->ToStringNoMetadata(); + return true; + } + if (OperandCanTrivallyChangeShape(instruction, operand)) { + VLOG(5) << "Operand can trivially change shape: " + << operand->ToStringNoMetadata(); + return true; + } + return false; + }); + VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for " + << instruction->ToStringNoMetadata() << ": " << result; + return result; } 
// Try to sink any reshape or transpose operands of `instruction` across it. We // do so if `instruction` is elementwise and all operands are equivalent // reshapes or transposes. -bool TrySinkReshapeOrTranspose(HloComputation* computation, - HloInstruction* instruction) { - if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { - std::vector operands = instruction->operands(); - HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction); - CHECK(old_reshape != nullptr); - Shape new_elementwise_shape = old_reshape->operand(0)->shape(); - for (size_t i = 0; i < operands.size(); ++i) { - // All scalar operands remain as-is, even if they're reshape or transpose, - // to simplify handling wrt special scalar broadcast rules for ops like - // Select. Scalar reshapes should be cheap anyways. - if (ShapeUtil::IsScalar(operands[i]->shape())) { - continue; - } - auto element_type = operands[i]->shape().element_type(); - switch (operands[i]->opcode()) { - case HloOpcode::kConstant: { - if (old_reshape->opcode() == HloOpcode::kReshape) { - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateReshape( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i])); - } else { - CHECK_EQ(old_reshape->opcode(), HloOpcode::kTranspose); - std::vector inverse_permutation = - InversePermutation(old_reshape->dimensions()); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateTranspose( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i], inverse_permutation)); - } - break; - } - case HloOpcode::kRng: { - CHECK_EQ(operands[i]->user_count(), 1); +StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, + HloInstruction* instruction) { + if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { + return false; + } + + std::vector operands = instruction->operands(); + HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction); + 
TF_RET_CHECK(old_reshape != nullptr); + Shape new_elementwise_shape = old_reshape->operand(0)->shape(); + + VLOG(3) << "** Trying to sink reshape or transpose: " + << instruction->ToStringNoMetadata() + << "\n\told reshape: " << old_reshape->ToStringNoMetadata() + << "\n\tnew elementwise shape: " + << ShapeUtil::HumanString(new_elementwise_shape); + for (size_t i = 0; i < operands.size(); ++i) { + // All scalar operands remain as-is, even if they're reshape or transpose, + // to simplify handling wrt special scalar broadcast rules for ops like + // Select. Scalar reshapes should be cheap anyways. + if (ShapeUtil::IsScalar(operands[i]->shape())) { + continue; + } + PrimitiveType element_type = operands[i]->shape().element_type(); + switch (operands[i]->opcode()) { + case HloOpcode::kConstant: { + if (old_reshape->opcode() == HloOpcode::kReshape) { + VLOG(3) << "Creating reshape for kConstant operand " << i << ": " + << operands[i]->ToStringNoMetadata(); operands[i] = instruction->parent()->AddInstruction( - operands[i]->CloneWithNewOperands( + HloInstruction::CreateReshape( ShapeUtil::ChangeElementType(new_elementwise_shape, element_type), - operands[i]->operands())); - break; + operands[i])); + } else { + TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose); + std::vector inverse_permutation = + InversePermutation(old_reshape->dimensions()); + operands[i] = instruction->parent()->AddInstruction( + HloInstruction::CreateTranspose( + ShapeUtil::ChangeElementType(new_elementwise_shape, + element_type), + operands[i], inverse_permutation)); } - case HloOpcode::kReshape: - case HloOpcode::kTranspose: - operands[i] = operands[i]->mutable_operand(0); - break; - default: - LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " - "transposes."; + break; } - } - if (HloOpcode::kFusion == instruction->opcode()) { - // Here we already know `instruction` is elementwise, and no operand is - // implicit broadcast as if it were the operands would not be 
equivalent - // reshapes, so all the fused instructions have the same dimensions. - for (const auto& fused_instruction : instruction->fused_instructions()) { - Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); - *shape->mutable_layout() = new_elementwise_shape.layout(); + case HloOpcode::kRng: { + CHECK_EQ(operands[i]->user_count(), 1); + operands[i] = instruction->parent()->AddInstruction( + operands[i]->CloneWithNewOperands( + ShapeUtil::ChangeElementType(new_elementwise_shape, + element_type), + operands[i]->operands())); + break; } - } - auto new_elementwise = - computation->AddInstruction(instruction->CloneWithNewOperands( - // `instruction` may change the element type, e.g., from - // operands[0] -> reshape -> convert (`instruction`) - // to - // operands[0] -> convert' -> reshape' - // - // In this case, convert' should have the same element type as - // `convert` and the same dimensions as operands[0]. - ShapeUtil::ChangeElementType(new_elementwise_shape, - instruction->shape().element_type()), - operands)); - std::unique_ptr new_reshape; - switch (old_reshape->opcode()) { case HloOpcode::kReshape: - new_reshape = HloInstruction::CreateReshape(instruction->shape(), - new_elementwise); - break; case HloOpcode::kTranspose: - new_reshape = HloInstruction::CreateTranspose( - instruction->shape(), new_elementwise, old_reshape->dimensions()); + operands[i] = operands[i]->mutable_operand(0); break; default: - LOG(FATAL) << "Bad opcode"; + LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " + "transposes."; } - TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction, - std::move(new_reshape))); - return true; } - return false; + if (HloOpcode::kFusion == instruction->opcode()) { + // Here we already know `instruction` is elementwise, and no operand is + // implicit broadcast as if it were the operands would not be equivalent + // reshapes, so all the fused instructions have 
the same dimensions. + for (const auto& fused_instruction : instruction->fused_instructions()) { + Shape* shape = fused_instruction->mutable_shape(); + *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); + *shape->mutable_layout() = new_elementwise_shape.layout(); + } + } + HloInstruction* new_elementwise = + computation->AddInstruction(instruction->CloneWithNewOperands( + // `instruction` may change the element type, e.g., from + // operands[0] -> reshape -> convert (`instruction`) + // to + // operands[0] -> convert' -> reshape' + // + // In this case, convert' should have the same element type as + // `convert` and the same dimensions as operands[0]. + ShapeUtil::ChangeElementType(new_elementwise_shape, + instruction->shape().element_type()), + operands)); + + std::unique_ptr new_reshape; + switch (old_reshape->opcode()) { + case HloOpcode::kReshape: + VLOG(3) << "Creating new reshape for new elementwise op: " + << new_elementwise->ToStringNoMetadata(); + new_reshape = + HloInstruction::CreateReshape(instruction->shape(), new_elementwise); + break; + case HloOpcode::kTranspose: + new_reshape = HloInstruction::CreateTranspose( + instruction->shape(), new_elementwise, old_reshape->dimensions()); + break; + default: + LOG(FATAL) << "Bad opcode"; + } + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + instruction, std::move(new_reshape))); + return true; } } // namespace @@ -237,9 +287,9 @@ StatusOr ReshapeMover::Run(HloModule* module) { bool changed = false; for (const auto& comp : module->computations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - if (TrySinkReshapeOrTranspose(comp.get(), instruction)) { - changed = true; - } + TF_ASSIGN_OR_RETURN(bool did_change, + TrySinkReshapeOrTranspose(comp.get(), instruction)); + changed |= did_change; } } return changed; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 
1831d775d4a0d8e4e60a31eb91dd1ca4393ec398..5217e85d4fc12e2adc412644b8f11fd11a58039a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -234,6 +234,58 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { EXPECT_EQ(select, computation->root_instruction()); } +// Tree looks like: +// +// param0 [1,128,1] +// | +// reshape [128,1] constant [128,1024] +// \ / +// multiply w/implicit broadcast [128,1024] +// +// The reshape mover would like to sink the reshape below the multiply. +// +// Previously we would attempt to insert a reshape of the constant to [1,128,1] +// (which is unsound, because it has a different number of elements) as +// preparation for sinking the reshape. +// +// To eliminate the unsoundness, we outlaw reshape sinking when one of the +// operands is implicitly broadcast in the elementwise consumer. +// +// TODO(b/37799338) However, it would be possible in this case to do a more +// in-depth analysis to get reshape movement to occur: +// +// 1. Note that the broadcast dimension (logical dimension 1) in the operands +// would map back to logical dimension 2 in the param0 node. +// 2. Match rank of the constant to the param0 node (by prepending a trivial 1 +// dimension). +// 3. Reshape to [128,1024] at the root. +// +// But this is not currently done. 
+TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0")); + auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {128, 1}), param0)); + Array2D a(128, 1024); + auto literal = LiteralUtil::CreateR2FromArray2D(a); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kMultiply, constant, reshape)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Constant(), op::Reshape(param0))); + + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Constant(), op::Reshape(param0))); + EXPECT_EQ(multiply, computation->root_instruction()); +} + // Tree looks like this: // // add1 diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 451bb8c7eadf3e2210788a722d8f75aa3050e30f..42450dfcae4be71af1002efb72b75857d5c80015 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -112,6 +112,16 @@ ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) { int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } +ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int ServiceOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + /* static */ StatusOr> Service::NewService( perftools::gputools::Platform* platform) { ServiceOptions 
default_options; @@ -126,9 +136,10 @@ int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN( - execute_backend, - Backend::CreateBackend(platform, options.number_of_replicas())); + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_options.set_number_of_replicas(options.number_of_replicas()); + TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); std::unique_ptr service(new Service( @@ -142,7 +153,10 @@ Service::CreateComputeConstantBackend() { PlatformUtil::GetSupportedPlatforms()); for (auto* platform : platforms) { if (platform->id() == se::host::kHostPlatformId) { - return Backend::CreateBackend(platform, /*replica_count=*/1); + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_options.set_number_of_replicas(1); + return Backend::CreateBackend(backend_options); } } return NotFound("CPU platform not found"); @@ -180,20 +194,24 @@ Service::Service(std::unique_ptr execute_backend, std::unique_ptr compute_constant_backend) : execute_backend_(std::move(execute_backend)), compute_constant_backend_(std::move(compute_constant_backend)) { - LOG(INFO) << Printf( - "XLA service %p executing computations on platform %s. 
Devices:", this, - execute_backend_->platform()->Name().c_str()); - for (int i = 0; i < execute_backend_->device_count(); ++i) { - if (execute_backend_->device_ordinal_supported(i)) { - se::StreamExecutor* executor = - execute_backend_->stream_executor(i).ValueOrDie(); - const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, - description.name().c_str(), - description.platform_version().c_str()); - } else { - LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + if (execute_backend_) { + LOG(INFO) << Printf( + "XLA service %p executing computations on platform %s. Devices:", this, + execute_backend_->platform()->Name().c_str()); + for (int i = 0; i < execute_backend_->device_count(); ++i) { + if (execute_backend_->device_ordinal_supported(i)) { + se::StreamExecutor* executor = + execute_backend_->stream_executor(i).ValueOrDie(); + const auto& description = executor->GetDeviceDescription(); + LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, + description.name().c_str(), + description.platform_version().c_str()); + } else { + LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + } } + } else { + VLOG(1) << "XLA compile-only service constructed"; } } @@ -286,7 +304,7 @@ StatusOr> Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options) { + const ExecutionOptions& execution_options, Backend* backend) { auto module_config = MakeUnique(program_shape); auto* computation_layout = module_config->mutable_entry_computation_layout(); @@ -326,7 +344,7 @@ StatusOr> Service::CreateModuleConfig( module_config->enable_hlo_profiling(true); } - module_config->set_replica_count(execute_backend_->Replicas().size()); + module_config->set_replica_count(backend->Replicas().size()); 
module_config->set_fast_math_disabled(execution_options.disable_fast_math()); module_config->set_seed(execution_options.seed()); @@ -474,7 +492,7 @@ StatusOr> Service::BuildAndCacheExecutable( std::unique_ptr executable_unique_ptr, BuildExecutable(versioned_handle, std::move(module_config), /*executable_for_compute_constant=*/false, arguments, - execute_backend_.get(), executor)); + backend, executor)); if (profile != nullptr) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -569,21 +587,21 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); - run_options.emplace_back(options, backend->StreamBorrower()); + run_options.emplace_back(options, backend->StreamBorrower(), + backend->inter_op_thread_pool()); } perftools::gputools::DeviceMemoryBase result; if (backend->Replicas().size() == 1) { TF_ASSIGN_OR_RETURN( - result, - ExecuteOnStreamWrapper>( - executable, &run_options[0], profile, execute_backend_.get(), - [&arguments](Executable* executable, - const ServiceExecutableRunOptions* run_options, - HloExecutionProfile* hlo_execution_profile) { - return executable->ExecuteOnStream(run_options, arguments, - hlo_execution_profile); - })); + result, ExecuteOnStreamWrapper>( + executable, &run_options[0], profile, backend, + [&arguments](Executable* executable, + const ServiceExecutableRunOptions* run_options, + HloExecutionProfile* hlo_execution_profile) { + return executable->ExecuteOnStream(run_options, arguments, + hlo_execution_profile); + })); } else { std::vector< tensorflow::gtl::ArraySlice> @@ -666,7 +684,8 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // the program and the argument allocations. 
TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, CreateModuleConfig(*program_shape, arg_allocations, - request.execution_options())); + request.execution_options(), + execute_backend_.get())); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -751,9 +770,10 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + arg->execution_options(), execute_backend_.get())); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -818,9 +838,10 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + arg->execution_options(), execute_backend_.get())); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -1141,7 +1162,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options)); + CreateModuleConfig(program_shape, {}, execution_options, + compute_constant_backend_.get())); TF_ASSIGN_OR_RETURN( std::shared_ptr executable, diff --git 
a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 9600f6989a40c9180d00ccabbeb29cb37a28900a..05a955137f8dfe7aa085058c5a6673ce8f2f77f1 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -63,9 +63,14 @@ class ServiceOptions { ServiceOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; + // Sets the thread pool size for parallel execution of an individual operator. + ServiceOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + private: perftools::gputools::Platform* platform_ = nullptr; int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; }; // The XLA service object, which is the same across all @@ -265,11 +270,11 @@ class Service : public ServiceInterface { tensorflow::gtl::ArraySlice arguments, const Backend* backend, int device_ordinal); - // Create a Hlo module config foe the given program shape and arguments. + // Create a Hlo module config for the given program shape and arguments. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options); + const ExecutionOptions& execution_options, Backend* backend); // Builds an Executable for the given parameters. 
If // executable_for_compute_constant is true, then the executable is intended to diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 0d4b214f5f3624971ae68e23f0f4fdba846f9178..017e5ef09ed2f52b862821e9408540d188a1edf5 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -30,10 +30,12 @@ class ServiceExecutableRunOptions { using StreamBorrower = std::function::SmartPtr>(int)>; - explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options, - StreamBorrower borrow_stream = nullptr) + explicit ServiceExecutableRunOptions( + ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, + tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) : run_options_(std::move(run_options)), - borrow_stream_(std::move(borrow_stream)) {} + borrow_stream_(std::move(borrow_stream)), + xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {} // Returns reference or pointer to `ExecutableRunOptions` member. const ExecutableRunOptions& run_options() const { return run_options_; } @@ -53,9 +55,15 @@ class ServiceExecutableRunOptions { : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); } + // Returns reference to thread pool for execution of XLA ops on CPU backend. 
+ tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const { + return xla_intra_op_thread_pool_; + } + private: ExecutableRunOptions run_options_; StreamBorrower borrow_stream_; + tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 338d63f1a002b490ac3017afafdf3743eb29b503..b2ef8ed486b5ab4643cb0e26fa6c18e1f3894a4b 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -244,8 +244,11 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "cannot concatenate arrays with different ranks: %lld vs %lld", - ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape)); + "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "(%s)", + ShapeUtil::Rank(*arg_shape), + ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), + ShapeUtil::HumanString(*shape).c_str()); } if (arg_shape->element_type() != shape->element_type()) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index cfb90e6e1d49ff49572977d938a53593970ad912..a0c88c6bbc23972bb6a0f3729e51ee0eaee72bc7 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -76,8 +76,7 @@ using InstructionOperandsPair = // the parent HLO computation of `dot`. // // Returns whether the module is changed. 
-bool FoldTransposeIntoDot(InstructionOperandsPair pair, - HloComputation* computation) { +bool FoldTransposeIntoDot(InstructionOperandsPair pair) { auto* dot = pair.first; std::vector instructions_to_fuse(1, dot); for (const int64 operand_index : pair.second) { @@ -89,7 +88,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair, return false; } - computation->CreateFusionInstruction( + dot->parent()->CreateFusionInstruction( instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); return true; } @@ -98,8 +97,7 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair, // `computation` is the parent HLO computation of `convolution`. // // Returns whether the module is changed. -bool FoldTransposeIntoConvolution(InstructionOperandsPair pair, - HloComputation* computation) { +bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; // We only support fusing the RHS transpose into convolution. @@ -135,8 +133,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair, auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), convolution.mutable_operand(0), &transpose_operand, convolution.window(), new_dnums); - TF_CHECK_OK(computation->ReplaceWithNewInstruction(&convolution, - std::move(new_conv))); + TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( + &convolution, std::move(new_conv))); return true; } @@ -152,8 +150,6 @@ TransposeFolding::TransposeFolding( StatusOr TransposeFolding::Run(HloModule* module) { // Modifying the graph while traversing is dangerous, so we find all folding // opportunities before actually folding them. 
- HloComputation* entry_computation = module->entry_computation(); - std::vector> foldable_dots; std::vector> foldable_convolutions; auto visit_fn = [this, &foldable_dots, @@ -175,14 +171,17 @@ StatusOr TransposeFolding::Run(HloModule* module) { } return tensorflow::Status::OK(); }; - TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn)); + + for (auto& comp : module->computations()) { + TF_RETURN_IF_ERROR(comp->Accept(visit_fn)); + } bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - changed |= FoldTransposeIntoDot(pair, entry_computation); + changed |= FoldTransposeIntoDot(pair); } for (InstructionOperandsPair& pair : foldable_convolutions) { - changed |= FoldTransposeIntoConvolution(pair, entry_computation); + changed |= FoldTransposeIntoConvolution(pair); } return changed; } diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 6643f541daeb5f3dd3f36e1063eea951e604ad69..c72d127ea86e4e9daf99dff4335c538c081f0605 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -41,9 +41,7 @@ class TransposeFoldingTest : public ::testing::Test { TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return gpu::ImplementedAsGemm(dot) - ? 
candidate_operands - : TransposeFolding::OperandIndices{}; + return candidate_operands; }, [](const HloInstruction& convolution, const TransposeFolding::OperandIndices& candidate_operands) { @@ -159,6 +157,50 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { EXPECT_EQ(6, callee_computation->instructions().size()); } +TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, + /*lhs=*/x, /*rhs=*/transpose_y)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(dot)); + + HloInstruction* call = module.OutlineExpressionFromComputation( + {transpose_y, dot}, "outlined", entry_computation); + + FoldTranspose(&module); + + // Instructions after folding: x, y, and the fusion. 
+ std::unordered_set instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(call)) + << "call is not in entry_computation."; + CHECK(instruction_set.empty()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* fusion = + call->called_computations().front()->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + + // The fusion instruction should contain two parameters, one transpose and + // one dot. + EXPECT_EQ(4, fusion->fused_instructions().size()); +} + // Test that a two dimension swap of the kernel gets folded into convolution. TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { auto builder = HloComputation::Builder("entry_computation"); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 98c51b48f9022c5f2d1e23b59a6ce775f3a48e0b..554adaf0e32f7cb896e07a59d5235ff84a11bb92 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -131,10 +131,9 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, } /* static */ StatusOr> -TuplePointsToAnalysis::Run(const HloModule* module, - const bool include_loop_fusion_instructions) { +TuplePointsToAnalysis::Run(const HloModule* module) { std::unique_ptr analysis( - new TuplePointsToAnalysis(module, include_loop_fusion_instructions)); + new TuplePointsToAnalysis(module)); TF_RETURN_IF_ERROR(analysis->Analyze()); return std::move(analysis); } @@ -145,17 +144,14 @@ Status TuplePointsToAnalysis::Analyze() { TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( 
PopulateDefinedBuffersAndAliases(computation->instructions())); - if (include_loop_fusion_instructions_) { - // Run points-to analysis on loop fusion instructions in 'computation'. - for (auto& instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kFusion || - instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) { - continue; - } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); - TF_RETURN_IF_ERROR(PopulateDefinedBuffersAndAliases( - instruction->fused_instructions())); + // Run points-to analysis on fusion instructions in 'computation'. + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kFusion) { + continue; } + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); } } @@ -482,9 +478,7 @@ string TuplePointsToAnalysis::ToString() const { for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); - if (include_loop_fusion_instructions_ && - instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (instruction->opcode() == HloOpcode::kFusion) { for (auto& fused : instruction->fused_instructions()) { InstructionToString(fused.get(), &output); } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index a384529171a7371c848ca8949d22cb6717d83a78..85a71b56ce5e9fb1a3441c302e18bd1fa7b68864 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -148,12 +148,9 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); // the potential sources of each buffer in each instruction's output. 
class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { public: - // Runs points-to analysis on 'module'. If 'include_loop_fusion_instructions' - // is true, includes fused instructions from each loop fusion instruction - // in 'module' in the points-to analysis. + // Runs points-to analysis on 'module'. static StatusOr> Run( - const HloModule* module, - const bool include_loop_fusion_instructions = false); + const HloModule* module); // Return the points-to set of an instruction. This describes the potential // sources of each buffer in the instruction's output. @@ -218,10 +215,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { string ToString() const; private: - explicit TuplePointsToAnalysis(const HloModule* module, - const bool include_loop_fusion_instructions) - : module_(module), - include_loop_fusion_instructions_(include_loop_fusion_instructions) {} + explicit TuplePointsToAnalysis(const HloModule* module) : module_(module) {} // Perform the analysis. Should be called immediately after constructing the // object and before calling GetPointsToSet. @@ -261,9 +255,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // The module this analysis is performed on. const HloModule* module_; - // Whether to run points-to analysis on loop fusion instructions in 'module_'. - const bool include_loop_fusion_instructions_; - // A map containing a PointsToSet for every HLO instruction. 
tensorflow::gtl::FlatMap> points_to_; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 808050bdabd188a51d03141fc7ebe3500b2cf110..87e1b058b79c0dc327cc1ad63a8cffa97c190df4 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -52,11 +52,10 @@ class TuplePointsToAnalysisTest : public HloTestBase { module_->AddEntryComputation(std::move(computation)); } - void RunAnalysis(const bool include_loop_fusion_instructions = false) { + void RunAnalysis() { CHECK_NOTNULL(module_.get()); - points_to_analysis_ = TuplePointsToAnalysis::Run( - module_.get(), include_loop_fusion_instructions) - .ConsumeValueOrDie(); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); } // Returns the LogicalBuffer defined at the given instruction and @@ -609,7 +608,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { auto* fusion = module_->entry_computation()->root_instruction(); EXPECT_THAT(fusion, op::Fusion(tuple_param0)); // Run points-to analysis (should include fused instructions from 'fusion'). - RunAnalysis(/*include_loop_fusion_instructions=*/true); + RunAnalysis(); // Check points-to set of fusion parameter associated with 'tuple_param0'. 
auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 34e8ee8acade129f2f43a399cb807b2032cd95a6..e9fcc9fa6666bb2e3c24252e1c0f5e8d763a5d48 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1928,6 +1928,12 @@ HloInstruction* ComputationLowerer::Visit( const OperationRequest& request = session_computation_.requests().at(handle.handle()); + auto add_instruction = [&](std::unique_ptr instruction) { + HloInstruction* hlo_instruction = + hlo_builder_.AddInstruction(std::move(instruction)); + hlo_instruction->set_metadata(request.request().metadata()); + return hlo_instruction; + }; HloInstruction* hlo_instruction; switch (request.request().op_case()) { case OpRequest::kRngRequest: { @@ -1936,7 +1942,7 @@ HloInstruction* ComputationLowerer::Visit( for (const ComputationDataHandle& param : rng_request.parameter()) { parameters.push_back(Visit(param, visited)); } - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRng( + hlo_instruction = add_instruction(HloInstruction::CreateRng( request.output_shape(), rng_request.distribution(), parameters)); break; } @@ -1944,9 +1950,8 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kConstantRequest: { const ConstantRequest& constant_request = request.request().constant_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CloneToUnique(constant_request.literal()))); + hlo_instruction = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CloneToUnique(constant_request.literal()))); break; } @@ -1955,17 +1960,15 @@ HloInstruction* ComputationLowerer::Visit( request.request().get_tuple_element_request(); HloInstruction* operand = Visit(get_tuple_element_request.operand(), visited); - hlo_instruction = - 
hlo_builder_.AddInstruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, - get_tuple_element_request.index())); + hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( + request.output_shape(), operand, get_tuple_element_request.index())); break; } case OpRequest::kSliceRequest: { const SliceRequest& slice_request = request.request().slice_request(); HloInstruction* operand = Visit(slice_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSlice( + hlo_instruction = add_instruction(HloInstruction::CreateSlice( request.output_shape(), operand, AsInt64Slice(slice_request.start_indices()), AsInt64Slice(slice_request.limit_indices()))); @@ -1979,10 +1982,9 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* start_indices = Visit(dynamic_slice_request.start_indices(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); + hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( + request.output_shape(), operand, start_indices, + AsInt64Slice(dynamic_slice_request.slice_sizes()))); break; } @@ -1996,7 +1998,7 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* start_indices = Visit(dynamic_update_slice_request.start_indices(), visited); hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + add_instruction(HloInstruction::CreateDynamicUpdateSlice( request.output_shape(), operand, update, start_indices)); break; } @@ -2010,9 +2012,8 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* operand = Visit(handle, visited); operands.push_back(operand); } - hlo_instruction = hlo_builder_.AddInstruction( - HloInstruction::CreateConcatenate(request.output_shape(), operands, - concatenate_request.dimension())); + hlo_instruction = 
add_instruction(HloInstruction::CreateConcatenate( + request.output_shape(), operands, concatenate_request.dimension())); break; } @@ -2021,10 +2022,9 @@ HloInstruction* ComputationLowerer::Visit( request.request().convolve_request(); HloInstruction* lhs = Visit(convolve_request.lhs(), visited); HloInstruction* rhs = Visit(convolve_request.rhs(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); + hlo_instruction = add_instruction(HloInstruction::CreateConvolve( + request.output_shape(), lhs, rhs, convolve_request.window(), + convolve_request.dimension_numbers())); break; } @@ -2033,17 +2033,15 @@ HloInstruction* ComputationLowerer::Visit( request.request().cross_replica_sum_request(); HloInstruction* operand = Visit(cross_replica_sum_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), operand)); + hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( + request.output_shape(), operand)); break; } case OpRequest::kInfeedRequest: { const InfeedRequest& infeed_request = request.request().infeed_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); + hlo_instruction = add_instruction(HloInstruction::CreateInfeed( + request.output_shape(), infeed_request.config())); break; } @@ -2051,9 +2049,8 @@ HloInstruction* ComputationLowerer::Visit( const OutfeedRequest& outfeed_request = request.request().outfeed_request(); HloInstruction* operand = Visit(outfeed_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction( - HloInstruction::CreateOutfeed(outfeed_request.shape(), operand, - outfeed_request.outfeed_config())); + hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( + outfeed_request.shape(), 
operand, outfeed_request.outfeed_config())); break; } @@ -2069,7 +2066,7 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* map_computation = ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateMap( + hlo_instruction = add_instruction(HloInstruction::CreateMap( request.output_shape(), operands, map_computation)); break; } @@ -2083,10 +2080,9 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* reduce_computation = ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateReduce( + request.output_shape(), operand, init_value, + AsInt64Slice(reduce_request.dimensions()), reduce_computation)); break; } @@ -2101,10 +2097,9 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* reduce_window_computation = ResolveComputation( reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( + request.output_shape(), operand, init_value, + reduce_window_request.window(), reduce_window_computation)); break; } @@ -2126,11 +2121,10 @@ HloInstruction* ComputationLowerer::Visit( select_and_scatter_request.select(), select_version); HloComputation* scatter_computation = ResolveComputation( select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateSelectAndScatter( - 
request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( + request.output_shape(), operand, select_computation, + select_and_scatter_request.window(), source, init_value, + scatter_computation)); break; } @@ -2151,9 +2145,8 @@ HloInstruction* ComputationLowerer::Visit( ShapeUtil::Rank(request.output_shape()) - ShapeUtil::Rank(operand->shape())); } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); + hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( + request.output_shape(), operand, broadcast_dimensions)); break; } @@ -2165,14 +2158,13 @@ HloInstruction* ComputationLowerer::Visit( if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { transposed = operand; } else { - transposed = - hlo_builder_.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice( - reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))); + transposed = add_instruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions( + InversePermutation(AsInt64Slice(reshape_request.dimensions())), + operand->shape()), + operand, AsInt64Slice(reshape_request.dimensions()))); } - hlo_instruction = hlo_builder_.AddInstruction( + hlo_instruction = add_instruction( HloInstruction::CreateReshape(request.output_shape(), transposed)); break; } @@ -2181,12 +2173,11 @@ HloInstruction* ComputationLowerer::Visit( const TransposeRequest& transpose_request = request.request().transpose_request(); HloInstruction* operand = Visit(transpose_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice( - 
transpose_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(transpose_request.dimensions()))); + hlo_instruction = add_instruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions( + InversePermutation(AsInt64Slice(transpose_request.dimensions())), + operand->shape()), + operand, AsInt64Slice(transpose_request.dimensions()))); break; } @@ -2194,10 +2185,9 @@ HloInstruction* ComputationLowerer::Visit( const ReverseRequest& reverse_request = request.request().reverse_request(); HloInstruction* operand = Visit(reverse_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); + hlo_instruction = add_instruction(HloInstruction::CreateReverse( + request.output_shape(), operand, + AsInt64Slice(reverse_request.dimensions()))); break; } @@ -2206,7 +2196,7 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* operand = Visit(pad_request.operand(), visited); HloInstruction* padding_value = Visit(pad_request.padding_value(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreatePad( + hlo_instruction = add_instruction(HloInstruction::CreatePad( request.output_shape(), operand, padding_value, pad_request.padding_config())); break; @@ -2214,7 +2204,7 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kRecvRequest: { const RecvRequest& recv_request = request.request().recv_request(); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRecv( + hlo_instruction = add_instruction(HloInstruction::CreateRecv( request.output_shape(), recv_request.channel_handle().handle())); break; } @@ -2222,10 +2212,9 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kParameterRequest: { const ParameterRequest& parameter_request = request.request().parameter_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateParameter( - 
parameter_request.parameter(), request.output_shape(), - parameter_request.name())); + hlo_instruction = add_instruction(HloInstruction::CreateParameter( + parameter_request.parameter(), request.output_shape(), + parameter_request.name())); break; } @@ -2233,7 +2222,7 @@ HloInstruction* ComputationLowerer::Visit( const ConvertRequest& convert_request = request.request().convert_request(); HloInstruction* operand = Visit(convert_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction( + hlo_instruction = add_instruction( HloInstruction::CreateConvert(request.output_shape(), operand)); break; } @@ -2250,7 +2239,7 @@ HloInstruction* ComputationLowerer::Visit( HloComputation* body = ResolveComputation(while_request.body(), body_version); HloInstruction* init = Visit(while_request.init(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateWhile( + hlo_instruction = add_instruction(HloInstruction::CreateWhile( request.output_shape(), condition, body, init)); break; } @@ -2262,9 +2251,8 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* rhs = Visit(ternary_op_request.rhs(), visited); HloInstruction* ehs = Visit(ternary_op_request.ehs(), visited); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); + hlo_instruction = add_instruction(HloInstruction::CreateTernary( + request.output_shape(), hlo_opcode, lhs, rhs, ehs)); break; } @@ -2279,9 +2267,8 @@ HloInstruction* ComputationLowerer::Visit( } auto hlo_opcode = VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); + hlo_instruction = add_instruction(HloInstruction::CreateVariadic( + request.output_shape(), hlo_opcode, operands)); break; } @@ -2296,7 +2283,7 @@ 
HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* call_computation = ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateCall( + hlo_instruction = add_instruction(HloInstruction::CreateCall( request.output_shape(), operands, call_computation)); break; } @@ -2308,9 +2295,8 @@ HloInstruction* ComputationLowerer::Visit( for (const ComputationDataHandle& operand : cc_request.operands()) { operands.push_back(Visit(operand, visited)); } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); + hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( + cc_request.shape(), operands, cc_request.call_target_name())); break; } @@ -2319,7 +2305,7 @@ HloInstruction* ComputationLowerer::Visit( request.request().unary_op_request(); HloInstruction* operand = Visit(unary_op_request.operand(), visited); auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateUnary( + hlo_instruction = add_instruction(HloInstruction::CreateUnary( request.output_shape(), hlo_opcode, operand)); break; } @@ -2347,23 +2333,22 @@ HloInstruction* ComputationLowerer::Visit( // identical to the HLO broadcast semantics so the broadcast_dimensions // field can just be passed to the instruction builder. HloInstruction* broadcasted_operand = - hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( + add_instruction(HloInstruction::CreateBroadcast( broadcast_shape, operand_to_broadcast, AsInt64Slice(binary_op_request.broadcast_dimensions()))); lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? 
broadcasted_operand : rhs; } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); + hlo_instruction = add_instruction(HloInstruction::CreateBinary( + request.output_shape(), hlo_opcode, lhs, rhs)); break; } case OpRequest::kTraceRequest: { const TraceRequest& trace_request = request.request().trace_request(); HloInstruction* operand = Visit(trace_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction( + hlo_instruction = add_instruction( HloInstruction::CreateTrace(trace_request.tag(), operand)); operand->set_tracing(hlo_instruction); break; @@ -2372,7 +2357,7 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kSendRequest: { const SendRequest& send_request = request.request().send_request(); HloInstruction* operand = Visit(send_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSend( + hlo_instruction = add_instruction(HloInstruction::CreateSend( operand, send_request.channel_handle().handle())); break; } @@ -2383,7 +2368,6 @@ HloInstruction* ComputationLowerer::Visit( default: LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); } - hlo_instruction->set_metadata(request.request().metadata()); (*visited)[handle.handle()] = hlo_instruction; return hlo_instruction; } diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 032b5cfac604a92bdf150a7fcee57e91bee65508..cf04cfde5003d70e26ce0a1543039c18c19282c9 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -59,6 +59,9 @@ TEST_F(UserComputationTest, SimpleComputation) { param_request.set_name("param0"); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle param_handle, computation.AddParameterInstruction(param_request)); + OpMetadata metadata; + metadata.set_op_name("meta"); + 
TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); OutfeedRequest outfeed_request; *outfeed_request.mutable_operand() = constant_handle; @@ -135,6 +138,8 @@ TEST_F(UserComputationTest, SimpleComputation) { // The root of the instruction should be the parameter instruction (not the // outfeed). EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); + EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), + "meta"); } } diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 2159386152b34e4f9b59ca14faa756e37551d724..c8851d2ca512450b4022e0f70d55399323b2fa08 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -21,7 +21,10 @@ limitations under the License. namespace xla { -// Defines the interface for an XLA service. +// Defines the interface for an XLA service on the client side. This service +// helps abstract around the actual implementation of a service - the service +// can be local (running in the same process), or remote - in which case an RPC +// stub is used as the implementation. class ServiceInterface { public: ServiceInterface() {} diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 57d91e4bfc1145faa25c9b5c57422c7653d4a163..2b32b78f0b7c39dbf16b61f17d98c81027d013b0 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "tensorflow/compiler/xla/index_util.h" @@ -28,6 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -675,7 +677,7 @@ namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in // DFS pre-order starting with the index. Status ForEachSubshapeHelper(const Shape& shape, - const ShapeUtil::VisitorFunction func, + const ShapeUtil::VisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); if (ShapeUtil::IsTuple(shape)) { @@ -692,7 +694,7 @@ Status ForEachSubshapeHelper(const Shape& shape, // Helper for ForEachMutableSubshape which visits the subshapes of the given // shape in DFS pre-order starting with the index. Status ForEachMutableSubshapeHelper( - Shape* shape, const ShapeUtil::MutatingVisitorFunction func, + Shape* shape, const ShapeUtil::MutatingVisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); if (ShapeUtil::IsTuple(*shape)) { @@ -709,13 +711,13 @@ Status ForEachMutableSubshapeHelper( } // namespace /* static */ Status ShapeUtil::ForEachSubshape(const Shape& shape, - VisitorFunction func) { + const VisitorFunction& func) { ShapeIndex index; return ForEachSubshapeHelper(shape, func, &index); } /* static */ Status ShapeUtil::ForEachMutableSubshape( - Shape* shape, MutatingVisitorFunction func) { + Shape* shape, const MutatingVisitorFunction& func) { ShapeIndex index; return ForEachMutableSubshapeHelper(shape, func, &index); } @@ -728,9 +730,17 @@ Status ForEachMutableSubshapeHelper( new_shape.add_dimensions(dim); } if (shape.has_layout()) { - new_shape.mutable_layout()->clear_minor_to_major(); + Layout* new_layout = new_shape.mutable_layout(); + new_layout->clear_minor_to_major(); for (auto index : Permute(permutation, 
shape.layout().minor_to_major())) { - new_shape.mutable_layout()->add_minor_to_major(index); + new_layout->add_minor_to_major(index); + } + if (shape.layout().padded_dimensions_size() > 0) { + new_layout->clear_padded_dimensions(); + for (auto dim : + Permute(permutation, shape.layout().padded_dimensions())) { + new_layout->add_padded_dimensions(dim); + } } } return new_shape; @@ -1013,6 +1023,144 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, check_input_unit_indices(output_shape, input_shape); } +/* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( + const Shape& input_shape, const Shape& output_shape) { + int64 input_rank = ShapeUtil::Rank(input_shape); + int64 output_rank = ShapeUtil::Rank(output_shape); + + // First, calculate an alignment of the dimensions. A consecutive sequence of + // input dimensions and output dimensions belong to the same alignment part if + // the products of their dimension bounds are the same. In the easiest case, + // an alignment part consists of one input dimension and one output dimension + // which both have the same dimension bound. An alignment part specifies which + // dimensions need to be kept together in a physical layout if we want a + // reshape to be a bitcast. The order of the alignment parts is defined by the + // physical layout of the input shape, so when we construct the layout for the + // output shape we just process the alignment parts in this order, and then + // layout the dimensions belonging to each part in descending (major to minor) + // order. + + // Stores the input and output dimension numbers where each alignment part + // starts. + std::vector> alignment; + alignment.push_back({0, 0}); + + // Stores a mapping from the input dimension to the alignment part it belongs + // to. 
+ std::vector dimension_to_alignment_index(input_rank); + int64 input_dimension_product = 1, output_dimension_product = 1; + for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) { + // Check if we have reached the end of an alignment part. + if (input_dimension_product == output_dimension_product && + input_dimension_product > 1) { + alignment.push_back({i, j}); + input_dimension_product = output_dimension_product = 1; + } + if (input_dimension_product < output_dimension_product || + j == output_rank) { + if (i == input_rank) { + return tensorflow::gtl::nullopt; + } + dimension_to_alignment_index[i] = alignment.size() - 1; + input_dimension_product *= input_shape.dimensions(i); + ++i; + } else { + output_dimension_product *= output_shape.dimensions(j); + ++j; + } + } + if (input_dimension_product != output_dimension_product) { + return tensorflow::gtl::nullopt; + } + // We also need to store an end element so that we know where the last + // alignment part ends. + alignment.push_back({input_rank, output_rank}); + + // Now check if the physical layout can potentially be aligned to the output + // shape by changing the physical layout of the output shape. We need to check + // that all dimension numbers that belong to the same alignment part appear + // consecutively, and are in descending order. However we can ignore any + // trivial dimension bounds of 1, because they can be placed anywhere. + auto input_dimension_numbers = input_shape.layout().minor_to_major(); + std::vector output_layout; + output_layout.reserve(output_rank); + for (int64 i = 0; i < input_rank;) { + int64 current_dimension_number = input_dimension_numbers[i]; + + // Skip trivial dimensions with a bound of 1. + if (input_shape.dimensions(current_dimension_number) == 1) { + ++i; + continue; + } + + // Calculate the number of non-trivial dimension bounds in the input shape + // belonging to the current alignment part. 
+ const int64 current_alignment_index = + dimension_to_alignment_index[current_dimension_number]; + // Because of the special end element that we added, we can be sure that + // 'current_alignment_index' is < alignment.size() - 1. + CHECK_LT(current_alignment_index, alignment.size() - 1); + int64 num_non_trivial_dimensions_in_alignment_part = 0; + for (int64 j = alignment[current_alignment_index].first; + j < alignment[current_alignment_index + 1].first; ++j) { + if (input_shape.dimensions(j) != 1) { + ++num_non_trivial_dimensions_in_alignment_part; + } + } + + // Check that the following 'num_non_trivial_dimensions_in_alignment_part' + // dimension numbers (ignoring dimension numbers with dimension bound 1) are + // in descending order and belong to the current alignment part. + for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; + ++i, ++j) { + if (i == input_rank) { + return tensorflow::gtl::nullopt; + } + // Skip trivial dimensions with a bound of 1. + if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { + --j; + continue; + } + // If the current dimension number belongs to a different alignment part, + // or the dimension numbers are not in descending order, we can return + // early. + if (dimension_to_alignment_index[input_dimension_numbers[i]] != + current_alignment_index || + input_dimension_numbers[i] > current_dimension_number) { + return tensorflow::gtl::nullopt; + } + current_dimension_number = input_dimension_numbers[i]; + } + + // The output dimension numbers that belong to the current alignment part + // need to appear in the same descending order as in the input. Again, we + // can skip dimensions with a bound of 1. + for (int64 j = alignment[current_alignment_index + 1].second - 1; + j >= alignment[current_alignment_index].second; --j) { + if (output_shape.dimensions(j) != 1) { + output_layout.push_back(j); + } + } + } + // Now add all the dimensions with dimension bound 1 at the end of + // 'output_layout'. 
+ for (int64 i = 0; i < output_rank; ++i) { + if (output_shape.dimensions(i) == 1) { + output_layout.push_back(i); + } + } + CHECK_EQ(output_layout.size(), output_rank); + std::vector dimension_sizes; + for (int64 i = 0; i < output_rank; ++i) { + dimension_sizes.push_back(output_shape.dimensions(i)); + } + Shape output_shape_with_layout = MakeShapeWithLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), + output_layout); + CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout)); + return output_shape_with_layout; +} + /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); @@ -1047,4 +1195,31 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } +/* static */ void ShapeUtil::ForEachIndex( + const Shape& shape, tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const IndexVisitorFunction& visitor_function) { + DCHECK_EQ(Rank(shape), base.size()); + DCHECK_EQ(incr.size(), base.size()); + DCHECK_EQ(count.size(), base.size()); + const Layout& layout = shape.layout(); + int64 rank = layout.minor_to_major_size(); + // Allows handling R0 arrays, such that the visitor function will be called + // once with the proper empty indexes. + int64 n = -1; + std::vector indexes(base.begin(), base.end()); + while (n < rank && visitor_function(indexes)) { + // Increments dimensions in minor to major order. 
+ for (n = 0; n < rank; ++n) { + int64 dim = layout.minor_to_major(n); + indexes[dim] += incr[dim]; + if (indexes[dim] < base[dim] + count[dim]) { + break; + } + indexes[dim] = base[dim]; + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 68e138e6aca9d2cf157466eca1ea6960e3c448e8..aaf8e84cfecb89080d690c66acd4f8d50ee17d56 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -299,13 +300,14 @@ class ShapeUtil { // pre-order starting with the entire shape (index {}). using VisitorFunction = std::function; - static Status ForEachSubshape(const Shape& shape, VisitorFunction func); + static Status ForEachSubshape(const Shape& shape, + const VisitorFunction& func); // Mutating variant of ForEachSubshape. using MutatingVisitorFunction = std::function; static Status ForEachMutableSubshape(Shape* shape, - MutatingVisitorFunction func); + const MutatingVisitorFunction& func); // Removes all degenerate dimensions (size one) from the given shape. The // stripped minor_to_major preserves the relative ordering of non-degenerate @@ -377,6 +379,15 @@ class ShapeUtil { static bool ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape); + // Find a physical layout for 'output_shape' such that + // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns + // true (where 'output_shape_with_layout' is 'output_shape' with the found + // layout). The layout of 'input_shape' is kept fixed. Returns + // 'output_shape_with_layout' if such a layout can be found, and an error + // otherwise. 
+ static tensorflow::gtl::optional AlignLayouts( + const Shape& input_shape, const Shape& output_shape); + // Returns a shape with the given dimension deleted. // For example: // • `DeleteDimension(1, T[m, n, k]) = T[m, k]` @@ -390,6 +401,19 @@ class ShapeUtil { static Shape FilterDimensions(const std::function& p, Shape shape); + // Iterates through all the shape indexes, in minor to major order, starting + // from the base indexes, incrementing by the incr steps, up to count + // (index[i] < base[i] + count[i]), and calls the visitor_function with the + // current index. + // The visitor_function visitor function should return true if it wants to + // continue, or false otherwise. + using IndexVisitorFunction = std::function&)>; + static void ForEachIndex(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const IndexVisitorFunction& visitor_function); + private: // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index b0a4b0c9a71ae8564d80d41169a4b3ab6af82e79..73538b8b88ecf14c00854d3c31715af8189bc21d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -20,10 +20,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { +using ::testing::ElementsAre; + TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) { Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1)); @@ -446,21 +449,21 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. - EXPECT_EQ(3, - ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {1, 1, 1, 1}), - ShapeUtil::MakeShape(S32, {1, 1, 1})) - .size()); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1, 1}), + ShapeUtil::MakeShape(S32, {1, 1, 1})), + ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1), + std::make_pair(2, 2))); } TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1_to_1x1x1x1) { // All input dimensions should be unmodified. One of the output dimensions is // modified because the output rank is larger by one. 
- EXPECT_EQ(3, - ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {1, 1, 1}), - ShapeUtil::MakeShape(S32, {1, 1, 1, 1})) - .size()); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1}), + ShapeUtil::MakeShape(S32, {1, 1, 1, 1})), + ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1), + std::make_pair(2, 2))); } TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) { @@ -468,11 +471,10 @@ TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) { // 4, 1, 3, 5, 6, 7 // | // 2, 6, 1, 5, 1, 42 - EXPECT_TRUE( - ContainersEqual(ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}), - ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})), - std::vector>({{3, 3}}))); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}), + ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})), + ElementsAre(std::make_pair(3, 3))); } TEST(ShapeUtilTest, ReshapeIsBitcast_3x4_6x2) { @@ -521,5 +523,58 @@ TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensions) { + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {3, 2, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(4, 3, 2, 1, 0, 5)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); + + aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {3, 2, 4, 35, 11})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(3, 2, 1, 0, 4)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + 
+TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { + Shape input = + ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 3, 8, 1, 5, 7, 1, 11, 1, 1}, + {5, 0, 4, 2, 1, 3, 6, 7, 9, 8}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +// A test case where the consecutive elements of the input shape belonging to +// the same layout part are not in descending order. +TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensionsWrongInputLayout) { + // Same physical layout as in AlignLayoutsWithoutTrivialDimensions, except + // that the first two dimension numbers are exchanged. + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {2, 3, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); + EXPECT_FALSE(aligned_shape); +} + +// A test case where the physical layout of the input shape does not place all +// dimensions that belong to the same alignment part consecutively. 
+TEST(AlignmentTest, + AlignLayoutsWithoutTrivialDimensionsNonConsecutiveAlignmentPart) { + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {3, 2, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 5, 77})); + EXPECT_FALSE(aligned_shape); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/status_macros_test.cc b/tensorflow/compiler/xla/status_macros_test.cc index 5563159776d11fda83aef86efb2480952689ef9d..dead17cdfa1e9f19e0ecfbc071e74e159ae82b5f 100644 --- a/tensorflow/compiler/xla/status_macros_test.cc +++ b/tensorflow/compiler/xla/status_macros_test.cc @@ -73,7 +73,7 @@ Status ReturnStatusError() { return (tensorflow::errors::Internal("foobar")); } using StatusReturningFunction = std::function; -StatusOr CallStatusReturningFunction(StatusReturningFunction func) { +StatusOr CallStatusReturningFunction(const StatusReturningFunction& func) { TF_RETURN_IF_ERROR(func()); return 42; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 12bc1e995611c244d830c6306725f6b34fdafd12..e0c2b9ab09c28a7b7a31917b9250bdca8016d1e0 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -200,11 +200,13 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:pool", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//third_party/eigen3", ], ) @@ -891,6 +893,7 @@ xla_test( name = "copy_test", srcs = ["copy_test.cc"], deps = [ + ":client_library_test_base", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", 
"//tensorflow/compiler/xla:util", @@ -1206,12 +1209,12 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -1361,6 +1364,7 @@ cc_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index e682f285e03b8b48cbb1aae34edd738fc723a944..2c748b6a7ee5bcd53fa89dbc9064eef8e5ee94a3 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -486,6 +486,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { ComputeAndCompareR1(&builder, {}, {}); } +TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { + // Disable fast-math because we're operating on NaNs. 
+ SetFastMathDisabled(true); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); + auto compare = builder.Ne(lhs, rhs); + + ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); +} + TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index c8030aade8c7c4d96658045f996801380289f2bf..0ad1cf3e8cfa69b07db18a80be093e44144b953c 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -127,6 +127,251 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } +struct R3ImplicitBroadcastSpec { + std::array output_bounds; + std::array minor2major_layout; + std::array input_bounds; + HloOpcode op; +} kR3ImplicitBroadcastTestCases[] = { + {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd}, + {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum}, + {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd}, +}; + +class BroadcastR3ImplicitTest + : public BroadcastSimpleTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { + const R3ImplicitBroadcastSpec& spec = GetParam(); + 
ComputationBuilder builder(client_, TestName()); + const Shape r3_shape = ShapeUtil::MakeShapeWithLayout( + F32, spec.output_bounds, spec.minor2major_layout); + Array3D r3_array(spec.output_bounds[0], spec.output_bounds[1], + spec.output_bounds[2]); + r3_array.FillRandom(1.0, 2.5, 56789); + auto r3_input = + LiteralUtil::Relayout(*LiteralUtil::CreateR3FromArray3D(r3_array), + LayoutUtil::MakeLayout(spec.minor2major_layout)); + std::unique_ptr r3_global_data = + client_->TransferToServer(*r3_input).ConsumeValueOrDie(); + + const Shape r3_implicit_shape = ShapeUtil::MakeShapeWithLayout( + F32, spec.input_bounds, spec.minor2major_layout); + Array3D r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1], + spec.input_bounds[2]); + r3_implicit_array.FillRandom(1.0, 0.2, 56789); + auto r3_implicit_input = LiteralUtil::Relayout( + *LiteralUtil::CreateR3FromArray3D(r3_implicit_array), + LayoutUtil::MakeLayout(spec.minor2major_layout)); + std::unique_ptr r3_implicit_global_data = + client_->TransferToServer(*r3_implicit_input).ConsumeValueOrDie(); + + auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); + auto r3_parameter = builder.Parameter(1, r3_shape, "input"); + ComputationDataHandle op; + switch (spec.op) { + case HloOpcode::kMinimum: { + auto tmp_op = builder.Min(r3_implicit_parameter, r3_parameter); + op.Swap(&tmp_op); + break; + } + case HloOpcode::kMaximum: { + auto tmp_op = builder.Max(r3_implicit_parameter, r3_parameter); + op.Swap(&tmp_op); + break; + } + case HloOpcode::kMultiply: { + auto tmp_op = builder.Mul(r3_implicit_parameter, r3_parameter); + op.Swap(&tmp_op); + break; + } + default: { + // Default to Add + auto tmp_op = builder.Add(r3_implicit_parameter, r3_parameter); + op.Swap(&tmp_op); + } + } + + Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], + spec.output_bounds[2]); + auto Each = ([&](tensorflow::gtl::ArraySlice indices, float* value) { + float r3_implicit = r3_implicit_array(indices[0] % 
spec.input_bounds[0], + indices[1] % spec.input_bounds[1], + indices[2] % spec.input_bounds[2]); + float r3 = r3_array(indices[0], indices[1], indices[2]); + switch (spec.op) { + case HloOpcode::kMinimum: { + *value = std::min(r3_implicit, r3); + break; + } + case HloOpcode::kMaximum: { + *value = std::max(r3_implicit, r3); + break; + } + case HloOpcode::kMultiply: { + *value = r3_implicit * r3; + break; + } + default: { + // Default to Add + *value = r3_implicit + r3; + break; + } + } + }); + + int n1 = expected_array.n1(); + int n2 = expected_array.n2(); + int n3 = expected_array.n3(); + for (int64 i = 0; i < n1; i++) { + for (int64 j = 0; j < n2; j++) { + for (int64 k = 0; k < n3; k++) { + Each({i, j, k}, &expected_array(i, j, k)); + } + } + } + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + ComputeAndCompareLiteral( + &builder, *expected, + {r3_implicit_global_data.get(), r3_global_data.get()}, + ErrorSpec(1e-7, 1e-7)); +} + +INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, + BroadcastR3ImplicitTest, + ::testing::ValuesIn(kR3ImplicitBroadcastTestCases)); + +// r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1: +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle r1h; + ComputationDataHandle r3h; + + Array3D r1d = {{{1}}, {{2}}}; + Array3D r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}; + auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h); + auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h); + + b.Add(r3h, r1h); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, + ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); 
+ b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}, {2}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { + ComputationBuilder b(client_, TestName()); + auto r1 = + b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = + b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + 
+XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}})); + auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + b.Add(r2, r1); + + auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + b.Add(r2, r1); + + auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 4170e0f4e2942bc71ddfa3d0f3a9d86ce2ecc823..1d998fe33ebf71a2b35f99a51038e874edacc046 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -17,18 +17,18 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -38,21 +38,60 @@ limitations under the License. namespace xla { namespace { -class ComputeConstantTest : public ClientLibraryTestBase { +// An enumerator for the client types that we want to iterate over in +// the various tests. 
+enum class ClientType { kLocal, kCompileOnly }; +ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly}; + +class ComputeConstantTest : public ::testing::Test { public: + explicit ComputeConstantTest( + perftools::gputools::Platform* platform = nullptr, + tensorflow::gtl::ArraySlice disabled_pass_names = {}) + : platform_(platform) { + legacy_flags::HloPassPipelineFlags* flags = + legacy_flags::GetHloPassPipelineFlags(); + flags->xla_disable_hlo_passes = + tensorflow::str_util::Join(disabled_pass_names, ","); + } + + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + Client* ClientOrDie(::perftools::gputools::Platform* platform, + ClientType client_type) { + if (client_type == ClientType::kLocal) { + StatusOr result = + ClientLibrary::GetOrCreateLocalClient(platform); + TF_CHECK_OK(result.status()) + << "could not create LocalClient for testing"; + return result.ValueOrDie(); + } else if (client_type == ClientType::kCompileOnly) { + StatusOr result = + ClientLibrary::GetOrCreateCompileOnlyClient(platform); + TF_CHECK_OK(result.status()) + << "could not create CompileOnlyClient for testing"; + return result.ValueOrDie(); + } + LOG(FATAL) << "invalid client_type value"; + } + StatusOr> ComputeConstantLiteral( - ComputationDataHandle operand, ComputationBuilder* builder, - Layout* output_layout = nullptr) { + Client* client, const ComputationDataHandle& operand, + ComputationBuilder* builder, Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto remote_computed, builder->ComputeConstant(operand, output_layout)); - TF_ASSIGN_OR_RETURN(auto computed, client_->Transfer(*remote_computed)); + TF_ASSIGN_OR_RETURN(auto computed, client->Transfer(*remote_computed)); return std::move(computed); } template - StatusOr ComputeConstantScalar(ComputationDataHandle operand, + StatusOr ComputeConstantScalar(Client* client, + const ComputationDataHandle& operand, ComputationBuilder* builder) { - 
TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(operand, builder)); + TF_ASSIGN_OR_RETURN(auto literal, + ComputeConstantLiteral(client, operand, builder)); return LiteralUtil::Get(*literal, {}); } @@ -63,163 +102,188 @@ class ComputeConstantTest : public ClientLibraryTestBase { return result.ok() ? result.ValueOrDie() : false; } - template - void ExpectConstantComputedScalar(ComputationDataHandle operand, - Scalar expected, - ComputationBuilder* builder) { - Scalar computed = ComputeConstantScalar(operand, builder); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(expected); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); - } + perftools::gputools::Platform* platform_; }; TEST_F(ComputeConstantTest, ScalarInt32Literal) { - ComputationBuilder b(client_, TestName()); - auto computation = b.ConstantR0(42); - EXPECT_TRUE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 42); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.ConstantR0(42); + EXPECT_TRUE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 42); + } } TEST_F(ComputeConstantTest, ScalarFloatAdd) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); - EXPECT_TRUE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 44.0f); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + 
b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); + EXPECT_TRUE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 44.0f); + } } TEST_F(ComputeConstantTest, ScalarRng) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), - ShapeUtil::MakeShape(F32, {})); - EXPECT_FALSE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - ASSERT_FALSE(value.ok()) - << "computing a RNG value should not be considered a constant"; + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), + ShapeUtil::MakeShape(F32, {})); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_FALSE(value.ok()) + << "computing a RNG value should not be considered a constant"; + } } TEST_F(ComputeConstantTest, DirectParam) { - ComputationBuilder b(client_, TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); - EXPECT_FALSE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) - << value.status(); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) + .contains("depends on parameter")) + << value.status(); + } } TEST_F(ComputeConstantTest, 
IndirectParam) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.Add(b.ConstantR0(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); - EXPECT_FALSE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) - << value.status(); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + b.Add(b.ConstantR0(1.0f), + b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) + .contains("depends on parameter")) + << value.status(); + } } // Test computation of an expression interspersed with param nodes but // the expression does not depend on the param nodes. 
TEST_F(ComputeConstantTest, UnrelatedParam) { - ComputationBuilder b(client_, TestName()); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); - auto constant_4 = b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto constant_4 = + b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); + auto not_constant_a = b.Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); - auto constant_9 = b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto constant_9 = + b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); + auto not_constant_b = b.Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = b.Add(constant_4, constant_9); + b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); - EXPECT_TRUE(IsConstant(constant_13, &b)); + EXPECT_TRUE(IsConstant(constant_13, &b)); - auto value = ComputeConstantScalar(constant_13, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 13.0f); + auto value = ComputeConstantScalar(client, constant_13, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 13.0f); + } } TEST_F(ComputeConstantTest, NonScalarAdd) { - ComputationBuilder b(client_, TestName()); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); - auto computation = - b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); - EXPECT_TRUE(IsConstant(computation, &b)); + auto computation 
= + b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); + EXPECT_TRUE(IsConstant(computation, &b)); - auto computed = ComputeConstantLiteral(computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + auto computed = ComputeConstantLiteral(client, computation, &b); + ASSERT_TRUE(computed.ok()) << computed.status(); + std::unique_ptr expected_literal = + LiteralUtil::CreateR1({4, 6}); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } } TEST_F(ComputeConstantTest, IntegerDivide) { - ComputationBuilder b(client_, TestName()); - auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); - EXPECT_TRUE(IsConstant(computation, &b)); - - auto computed = ComputeConstantLiteral(computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); -} + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); + EXPECT_TRUE(IsConstant(computation, &b)); -XLA_TEST_F(ComputeConstantTest, Layout) { - ComputationBuilder b(client_, TestName()); - - std::vector> layouts = {{0, 1}, {1, 0}}; - for (const std::vector& layout : layouts) { - auto layout_proto = LayoutUtil::MakeLayout(layout); - auto computed = - ComputeConstantLiteral(b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), - &b, &layout_proto); + auto computed = ComputeConstantLiteral(client, computation, &b); ASSERT_TRUE(computed.ok()) << computed.status(); - - std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - layout); - LiteralTestUtil::AssertEqualShapesAndLayouts( - 
expected_literal->shape(), computed.ValueOrDie()->shape()); + std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); } } +XLA_TEST_F(ComputeConstantTest, Layout) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + + std::vector> layouts = {{0, 1}, {1, 0}}; + for (const std::vector& layout : layouts) { + auto layout_proto = LayoutUtil::MakeLayout(layout); + auto computed = ComputeConstantLiteral( + client, + b.Add(b.ConstantR2({{1, 2}, {3, 4}}), + b.ConstantR2({{10, 20}, {30, 40}})), + &b, &layout_proto); + ASSERT_TRUE(computed.ok()) << computed.status(); + + std::unique_ptr expected_literal = + test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, + layout); + LiteralTestUtil::AssertEqualShapesAndLayouts( + expected_literal->shape(), computed.ValueOrDie()->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } + } +} + // This test is permanently disabled on CPU because it requires that the // backend used for execution is different than the backend used for // ComputeConstant which is always cpu. TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) { // Compute a trivial constant, then try to use the value in an Execute // call. This should fail because the constant resides on the CPU and the - // Execute call is executed on a different backend. - ComputationBuilder constant_b(client_, TestName()); + // Execute call is executed on a different backend. This test only makes + // sense with LocalClient, since CompileOnlyClient does not support + // execution. 
+ Client* client = ClientOrDie(platform_, ClientType::kLocal); + ComputationBuilder constant_b(client, TestName()); auto constant = constant_b.ConstantR0(42); auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie(); - auto literal = client_->Transfer(*handle).ConsumeValueOrDie(); + auto literal = client->Transfer(*handle).ConsumeValueOrDie(); LiteralTestUtil::ExpectR0Equal(42, *literal); // Build trivial computation which takes one parameter. - ComputationBuilder b(client_, TestName()); + ComputationBuilder b(client, TestName()); b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0")); auto computation = b.Build().ConsumeValueOrDie(); // Try to use value from ComputeConstant in Execute. - auto execute_status = client_->Execute(computation, {handle.get()}); + auto execute_status = client->Execute(computation, {handle.get()}); EXPECT_FALSE(execute_status.ok()); EXPECT_THAT( execute_status.status().error_message(), diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index e645e2336190c706912f94c0662bca08f5dc281a..63bfac441d3c1f7aa257a7f9fc81df98f47551d5 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -57,6 +57,15 @@ XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto concatenated = builder.ConcatInDim({a}, 0); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // Show that we can't concatenate R0 with R0 because we can't name the dimension // to concatenate on. 
XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 29e29505333b64926cdd0b3e9fe7ef3407eaaec2..8ea97e67d640d97baa70cddf60f3336a8849552a 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -256,6 +257,22 @@ XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0312_MultipleTilesPerLayer) { TestCopyConstantLayoutR4(2, 14, 5, 35, {0, 3, 1, 2}); } +using CopyOpClientTest = ClientLibraryTestBase; + +XLA_TEST_F(CopyOpClientTest, Copy0x0) { + Shape in_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {0, 1}); + Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); + auto empty = LiteralUtil::CreateFromShape(in_shape); + + ComputationBuilder builder(client_, TestName()); + auto param0 = builder.Parameter(0, in_shape, "input"); + auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); + + auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) + .ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqual(*empty, *actual); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index dc54c9defec2394049c38781a8d02fc8bd05158a..8b5b38b0b4b9d91f9491648e9c6ee6301ed74ff7 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -29,22 +29,23 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); *out = **in + 2.0f; } -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; *out = array[0] + array[1] + array[2] + array[3]; } -extern "C" void __attribute__((visibility("default"))) +extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 94f34f753b7ff8799cf9b505e1a762c9ba640389..cc3c4a2a5e115d7791e8574f4ead17f77dcd5e7c 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -52,7 +52,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { builder.ConstantR0(42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); - // A result can be transfered an arbitrary number of times. Add an extra + // A result can be transferred an arbitrary number of times. Add an extra // transfer here so we're not just testing that a second call to Transfer // fails. 
ASSERT_IS_OK(client_->Transfer(*global_data).status()); diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 180e8514102d115a169b327a26a544bbeb1c8499..cdb4498f4ed1e4f7fb2ad7a29a1cec4e26b76ed3 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -109,7 +109,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR1(const std::vector& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const std::vector& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -127,7 +127,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR2(const Array2D& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const Array2D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -145,7 +145,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR3(const Array3D& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const Array3D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. 
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 4e956bc00c8fcbf0cd200bc2ae5b8f4ccfe63694..f741ff38b55933291e6b0c942efc4a37c61a8f4b 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -111,8 +111,9 @@ StatusOr HloTestBase::Execute( backend_->eigen_intra_op_thread_pool_device()); HloExecutionProfile hlo_execution_profile; - ServiceExecutableRunOptions service_run_options(run_options, - backend_->StreamBorrower()); + ServiceExecutableRunOptions service_run_options( + run_options, backend_->StreamBorrower(), + backend_->inter_op_thread_pool()); TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase result, executable->ExecuteOnStream(&service_run_options, arguments, diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index ef81db6fd66502f9debf180b418d9c30917109aa..23453db57bc4a5db0d3a4f7c327e3313333d1ae2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -314,7 +314,7 @@ class NearComparator { private: // EXPECTs that the two given scalar values are within the error bound. Keeps - // track of how many mismatches have occured to keep the size of the output + // track of how many mismatches have occurred to keep the size of the output // manageable. template bool ExpectValuesNear(NativeT expected, NativeT actual) { diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index aeadc023cc0649cb8e69c3aa981d7f347b3a1a1f..4f98083033310baf6ec95de0d2331d1aff8f3f7d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include #include #include "tensorflow/compiler/xla/array2d.h" @@ -171,6 +172,36 @@ class LiteralTestUtil { tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal); + // Creates a literal with the supplied shape, and uses the provided value + // generator to populate the literal's values. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation, and using the engine as entropy generator. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, typename E, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation. + // Returns the new literal object, or an error Status if failed. 
+ template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, T mean, T stddev); + private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); }; @@ -270,6 +301,40 @@ template ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error); } +template +/* static */ StatusOr> +LiteralTestUtil::CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + TF_RET_CHECK(shape.element_type() == type); + std::unique_ptr literal = LiteralUtil::CreateFromShape(shape); + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + literal.get(), [&](tensorflow::gtl::ArraySlice indexes) { + return generator(indexes); + })); + return std::move(literal); +} + +template +/* static */ StatusOr> +LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, + T stddev) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + std::normal_distribution generator(mean, stddev); + return CreateRandomLiteral( + shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { + return generator(*engine); + }); +} + +template +/* static */ StatusOr> +LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { + std::minstd_rand0 engine; + return CreateRandomLiteral(shape, &engine, mean, stddev); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 7ea83a9e956ca8b5bb26ea6aaa844d2b63107328..52816dc72cc4d094054b2aea72f0cc63c7ff478d 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -42,7 +42,7 @@ xla::Computation Doubler(xla::Client* client) { int main(int argc, char** argv) { 
tensorflow::port::InitMain(argv[0], &argc, &argv); - auto client = xla::ClientLibrary::LocalClientOrDie(); + auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie(); xla::ComputationBuilder builder(client, "aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); @@ -74,7 +74,7 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); xla::Computation computation = builder.Build().ConsumeValueOrDie(); - xla::LocalClient::AheadOfTimeComputationInstance instance{ + xla::CompileOnlyClient::AotComputationInstance instance{ &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; xla::cpu::CpuAotCompilationOptions options( diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 7fe4c9020f4c67ecc9888425cf0a2c358ad49e6d..7fcf687655a98d3ee972f8d3b784be655410a313 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -17,12 +17,19 @@ limitations under the License. #include +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -91,16 +98,34 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const { return allocator_; } +// Define this in .cc file to avoid having to include eigen or forward declare +// these types in the header. 
+struct LocalClientTestBase::EigenThreadPoolWrapper { + explicit EigenThreadPoolWrapper() + : pool(new tensorflow::thread::ThreadPool( + tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)), + wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), + device(new Eigen::ThreadPoolDevice(wrapper.get(), + wrapper->NumThreads())) {} + + std::unique_ptr pool; + std::unique_ptr wrapper; + std::unique_ptr device; +}; + LocalClientTestBase::LocalClientTestBase( perftools::gputools::Platform* platform) : local_client_( - ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()) { + ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()), + thread_pool_wrapper_(new EigenThreadPoolWrapper()) { stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform()) .ValueOrDie()[local_client_->default_device_ordinal()]; transfer_manager_ = TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie(); } +LocalClientTestBase::~LocalClientTestBase() {} + std::unique_ptr LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) { return LiteralToScopedShapedBuffer(literal, @@ -190,8 +215,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ExecutableRunOptions run_options; run_options.set_inter_op_thread_pool( local_client_->backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - local_client_->backend().eigen_intra_op_thread_pool_device()); + run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get()); run_options.set_allocator(GetOrCreateAllocator(local_client_->platform())); return run_options; } diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 4e7b05cea60887eec628ce9b4848321e721030e5..e3c3bb46cf26cc742b7abb39a3e457d823d829ec 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -74,8 +74,10 @@ 
class TestAllocator : public StreamExecutorMemoryAllocator { // A base class for tests which exercise the LocalClient interface. class LocalClientTestBase : public ::testing::Test { protected: + struct EigenThreadPoolWrapper; explicit LocalClientTestBase( perftools::gputools::Platform* platform = nullptr); + virtual ~LocalClientTestBase(); static TestAllocator* GetOrCreateAllocator( perftools::gputools::Platform* platform); @@ -142,6 +144,8 @@ class LocalClientTestBase : public ::testing::Test { TransferManager* transfer_manager_; LocalClient* local_client_; + + std::unique_ptr thread_pool_wrapper_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 0cd0f97b0621d771ae039f0be6bd6c67161b49a4..5a6aa467e54f31b57d04b9c1f0cf82cd6295903d 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -55,7 +56,7 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); - EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions())); + EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); LiteralUtil::EachCell(*actual, [=](tensorflow::gtl::ArraySlice, T value) { EXPECT_LE(a, value); @@ -75,7 +76,7 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { auto actual, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options)); - EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions())); 
+ EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); int32 sum = 0; LiteralUtil::EachCell( *actual, [&sum](tensorflow::gtl::ArraySlice, uint32 value) { diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index d00a3175344dffcab08116678a8c46782aa8cf64..feb2b465fca6b1ffda190025568470e8daf297a3 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -61,7 +61,7 @@ namespace { class ReduceTest : public ClientLibraryTestBase { protected: ReduceTest() { - // Implementation note: layed out z >> y >> x by default. + // Implementation note: laid out z >> y >> x by default. // clang-format off literal_2d_ = LiteralUtil::CreateR2({ // x0 x1 x2 diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 56501e43b5c5d965ea4305f2ca88909b253ed273..c3b768579a401706eff4a2a24d840da84080d26d 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -43,7 +43,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { public: ReduceWindowTest() : builder_(client_, TestName()) {} - void ReduceWindowAdd(ComputationDataHandle input, + void ReduceWindowAdd(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -52,7 +52,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { window_dimensions, window_strides, padding); } - void ReduceWindowMax(ComputationDataHandle input, + void ReduceWindowMax(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -61,7 +61,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { CreateScalarMax(), window_dimensions, window_strides, padding); } - void ReduceWindowMin(ComputationDataHandle input, 
+ void ReduceWindowMin(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -182,6 +182,7 @@ TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); } + // TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmall) { Array4D input_array(2, 2, 4, 16); @@ -368,6 +369,16 @@ TEST_F(ReduceWindowTest, Add2x2In2x2Disjoint) { ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); } +TEST_F(ReduceWindowTest, Add1x2In2x2Same) { + Array2D input_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto input = builder_.ConstantR2FromArray2D(input_array); + ReduceWindowAdd(input, {1, 2}, {1, 1}, Padding::kSame); + Array2D expected({ + {3.0f, 2.0f}, {7.0f, 4.0f}, + }); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { Array3D input_array(2, 1, 2); input_array(0, 0, 0) = 1000; diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 839ae42a194381396e387f0e6e8a018d6fbd5cff..c5f20b9ca1db1812f52a4d6f568ff9093016a90b 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -67,6 +67,22 @@ XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); } +XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + a = builder.Neg(a); + auto reshape = + builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + + 
ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, + zero_error_spec_); +} + XLA_TEST_F(ReshapeTest, Trivial0x3) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); @@ -75,6 +91,24 @@ XLA_TEST_F(ReshapeTest, Trivial0x3) { ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-05-15 +// with an incorrect result rank. +XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = + LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {}, {param0_data.get()}, + zero_error_spec_); +} + XLA_TEST_F(ReshapeTest, Trivial3x0) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 1d9baf5de102752fe4b47af22ce127ba934a2579..535e5b605b4f68671c9b6a8af4a12732f88e744e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -153,6 +153,7 @@ cc_binary( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 
8b96e13489774539b50022808975db56c5ddc6f7..1f0ca31d6d6d57507c8639bec83d66f36cb44ab8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,7 +35,7 @@ limitations under the License. namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -50,23 +51,35 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } Computation computation = computation_status.ConsumeValueOrDie(); - std::unique_ptr program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); + if (compile) { + std::unique_ptr program_shape = + client->GetComputationShape(computation).ConsumeValueOrDie(); - std::vector layouts; - for (int i = 0; i < program_shape->parameters_size(); ++i) { - layouts.push_back(&program_shape->parameters(i)); - } - StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); + std::vector layouts; + for (int i = 0; i < program_shape->parameters_size(); ++i) { + layouts.push_back(&program_shape->parameters(i)); + } + StatusOr> executable = + local_service->CompileExecutable( + computation.handle(), layouts, &program_shape->result(), + /*device_ordinal=*/0, /*has_hybrid_result=*/true); + + const 
HloModule& module = executable.ValueOrDie()->module(); - const HloModule& module = executable.ValueOrDie()->module(); + fprintf(stdout, "HLO compiled for %s backend:\n%s\n", + local_service->backend().platform()->Name().c_str(), + module.ToString().c_str()); + } else { + const ComputationTracker& tracker = local_service->computation_tracker(); + UserComputation* user_computation = + tracker.Resolve(computation.handle()).ConsumeValueOrDie(); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + std::unique_ptr module = + tracker.BuildHloModule(versioned_handle).ConsumeValueOrDie(); - fprintf(stdout, "HLO for %s backend:\n%s\n", - local_service->backend().platform()->Name().c_str(), - module.ToString().c_str()); + fprintf(stdout, "%s\n", module->ToString().c_str()); + } } } @@ -74,10 +87,21 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } // namespace xla int main(int argc, char** argv) { - tensorflow::port::InitMain(argv[0], &argc, &argv); + bool compile = false; + std::vector flag_list = { + {"compile", &compile, + "If true, compile the computation using the default client before " + "dumping the HLO. 
Otherwise dump the raw (uncompiled) HLO."}, + }; + const xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); + xla::tools::RealMain(args, compile); return 0; } diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 8258031a2c5119d085a483a0826f7284897dcee3..8d8e66715a3626825195f875a5942e1b1db67f92 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" +#include + namespace xla { using ::tensorflow::string; @@ -32,6 +34,8 @@ using ::tensorflow::uint16; using ::tensorflow::uint32; using ::tensorflow::uint64; +using ::Eigen::half; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index a711b5035d842cd26945b2dac1159392813d56ab..d467178cb528a93b2c1030fc72d054cc0edf95b6 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -33,7 +33,7 @@ namespace { // Adds a backtrace to the provided status iff the xla_status_add_backtrace flag // is set. This is useful for quickly tracing status errors observed coming out // of the service. 
-Status MaybeAddBacktrace(Status prior) { +Status MaybeAddBacktrace(const Status& prior) { DCHECK(!prior.ok()); if (legacy_flags::GetUtilFlags()->xla_status_add_backtrace) { return Status{prior.code(), @@ -153,16 +153,26 @@ string Reindent(tensorflow::StringPiece original, }); } +bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { + if (rank != permutation.size()) { + return false; + } + std::vector output(permutation.size(), -1); + for (auto index : permutation) { + CHECK_GE(index, 0); + CHECK_LT(index, rank); + output[index] = 0; + } + return std::find(output.begin(), output.end(), -1) == output.end(); +} + std::vector InversePermutation( tensorflow::gtl::ArraySlice input_permutation) { + DCHECK(IsPermutation(input_permutation, input_permutation.size())); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { output_permutation[input_permutation[i]] = i; } - DCHECK_EQ( - 0, std::count(output_permutation.begin(), output_permutation.end(), -1)); - DCHECK(std::is_permutation(input_permutation.begin(), input_permutation.end(), - output_permutation.begin())); return output_permutation; } @@ -196,6 +206,15 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) { return padding_config; } +bool HasInteriorPadding(const PaddingConfig& config) { + for (const auto& dim : config.dimensions()) { + if (dim.interior_padding() != 0) { + return true; + } + } + return false; +} + string HumanReadableNumFlops(double flops, double nanoseconds) { if (nanoseconds == 0) { return "NaN FLOP/s"; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 32b5fbba0032c04117c2109b5452e098b03e0947..42d5c1d15501fb912551a044414e6fa0c83283b8 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -39,6 +39,13 @@ limitations under the License. namespace xla { +// Ranks greater than 8 are very rare, so use InlinedVector to store +// the bounds and indices. 
And for the rare cases of ranks greater than 8, +// the InlinedVector will just behave like an std::vector<> and allocate the +// memory to store its values. +static constexpr int kInlineRank = 8; +using DimensionVector = tensorflow::gtl::InlinedVector; + // RAII timer that logs with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it // spits out the human-readable duration form. @@ -139,6 +146,18 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2, std::equal(std::begin(c1), std::end(c1), std::begin(c2), p)); } +// Performs a copy of count values from src to dest, using different strides for +// source and destination. The source starting index is src_base, while the +// destination one is dest_base. +template +void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, + int64 dest_stride, tensorflow::gtl::ArraySlice src, + int64 src_base, int64 src_stride, int64 count) { + for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) { + dest[dest_base] = static_cast(src[src_base]); + } +} + // Adds some context information to the error message in a // Status. This is useful as Statuses are // propagated upwards. @@ -165,6 +184,9 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); string Reindent(tensorflow::StringPiece original, tensorflow::StringPiece indentation); +// Checks whether permutation is a permutation of the [0, rank) integer range. +bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); + // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. // @@ -175,12 +197,11 @@ template