diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 5b37028c509a2ba2e331a463834dcda18ba69584..5bf13ee152ad95d9d9b4cabd80f9b7d7de8b1813 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -6,7 +6,7 @@ If you open a GitHub issue, here is our policy: 1. It must be a bug or a feature request. 2. The form below must be filled out. -3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorflow/issues). +3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues). **Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow. @@ -17,6 +17,7 @@ If you open a GitHub issue, here is our policy: - **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**: - **TensorFlow installed from (source or binary)**: - **TensorFlow version (use command below)**: +- **Python version**: - **Bazel version (if compiling from source)**: - **CUDA/cuDNN version**: - **GPU model and memory**: diff --git a/README.md b/README.md index cbc94c1ab2bbc62bc5cb13f7a5999031defeac74..90d50f676811a45479556ad1868dfa1f3d9fe29d 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ or more CPUs or GPUs in a desktop, server, or mobile device without rewriting code. TensorFlow also includes TensorBoard, a data visualization toolkit. TensorFlow was originally developed by researchers and engineers -working on the Google Brain team within Google's Machine Intelligence research +working on the Google Brain team within Google's Machine Intelligence Research organization for the purposes of conducting machine learning and deep neural networks research. 
The system is general enough to be applicable in a wide variety of other domains, as well. @@ -34,12 +34,12 @@ and discussion.** People who are a little more adventurous can also try our nightly binaries: -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 
3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) -* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-py2-none-any.whl) ([build 
history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) -* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.0-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/)) +* Linux CPU-only: [Python 
2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) +* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 
3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) +* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.2.1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) +* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.2.1-py3-none-any.whl) ([build 
history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) +* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.2.1-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows,PY=36/)) +* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.2.1-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/M=windows-gpu,PY=36/)) * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) diff --git a/RELEASE.md b/RELEASE.md index 9875838d7e187b45f13892fa4629e6e0f842a234..0cd4eef5d6562ea45e9318856b373ea2dbe89314 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,9 @@ +# Release 1.2.1 + +## Bug Fixes and Other Changes +* Updating markdown version required to >= 2.6.8. 
+* Support tensors as dropout rates again, by removing the min(max(..)) + # Release 1.2.0 ## Major Features and Improvements diff --git a/configure b/configure index 602124225fe0712135798a779e509a16fe2ccc79..4c6cba216914c3d7f07466241e9c5cffbcdcc1c7 100755 --- a/configure +++ b/configure @@ -25,6 +25,10 @@ function is_windows() { [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]] } +function is_ppc64le() { + [[ "$(uname -m)" == "ppc64le" ]] # command substitution: compare the machine hardware name +} + function sed_in_place() { sed -e $1 $2 > "$2.bak" mv "$2.bak" $2 @@ -294,7 +298,12 @@ fi # TF_NEED_MKL ## Set up architecture-dependent optimization flags. if [ -z "$CC_OPT_FLAGS" ]; then - default_cc_opt_flags="-march=native" + if is_ppc64le; then # call the function; "[ is_ppc64le ]" only tests a non-empty string + # gcc on ppc64le does not support -march, use mcpu instead + default_cc_opt_flags="-mcpu=native" + else + default_cc_opt_flags="-march=native" + fi read -p "Please specify optimization flags to use during compilation when bazel option "\ "\"--config=opt\" is specified [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS if [ -z "$CC_OPT_FLAGS" ]; then @@ -539,9 +548,9 @@ done # Set default CUDA version if not set if [ -z "$TF_CUDA_VERSION" ]; then TF_CUDA_VERSION="8.0" - export TF_CUDA_VERSION + export TF_CUDA_VERSION fi -write_action_env_to_bazelrc "TF_CUDA_VERSION" "$TF_CUDA_VERSION" +write_action_env_to_bazelrc "TF_CUDA_VERSION" "$TF_CUDA_VERSION" # Set up which gcc nvcc should use as the host compiler # No need to set this on Windows @@ -590,6 +599,9 @@ while true; do # Result returned from "read" will be used unexpanded. That make "~" unusable. # Going through one more level of expansion to handle that. 
CUDNN_INSTALL_PATH=`"${PYTHON_BIN_PATH}" -c "import os; print(os.path.realpath(os.path.expanduser('${CUDNN_INSTALL_PATH}')))"` + if is_windows; then + CUDNN_INSTALL_PATH="$(cygpath -m "$CUDNN_INSTALL_PATH")" + fi fi if [[ -z "$TF_CUDNN_VERSION" ]]; then @@ -656,16 +668,22 @@ write_action_env_to_bazelrc "TF_CUDNN_VERSION" "$TF_CUDNN_VERSION" # Configure the compute capabilities that TensorFlow builds for. # Since Cuda toolkit is not backward-compatible, this is not guaranteed to work. +function get_native_cuda_compute_capabilities { + device_query_bin="$CUDA_TOOLKIT_PATH/extras/demo_suite/deviceQuery" # Also works on Windows without .exe + "$device_query_bin" | grep 'Capability' | grep -o '[0-9]*\.[0-9]*' | paste -sd, - + return 0 # always report success even if device detection fails, so a failed probe never aborts the whole configure (and never exits the caller's shell) +} while true; do fromuser="" - default_cuda_compute_capabilities="3.5,5.2" + native_cuda_compute_capabilities=$(get_native_cuda_compute_capabilities) + default_cuda_compute_capabilities=${native_cuda_compute_capabilities:-"3.5,5.2"} if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then cat << EOF Please specify a list of comma-separated Cuda compute capabilities you want to build with. You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus. Please note that each additional compute capability significantly increases your build time and binary size. EOF - read -p "[Default is: \"3.5,5.2\"]: " TF_CUDA_COMPUTE_CAPABILITIES + read -p "[Default is: \"$default_cuda_compute_capabilities\"]: " TF_CUDA_COMPUTE_CAPABILITIES fromuser=1 fi if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then @@ -705,8 +723,13 @@ if is_windows; then write_to_bazelrc "test --config=win-cuda" else # If CUDA is enabled, always use GPU during build and test. 
- write_to_bazelrc "build --config=cuda" - write_to_bazelrc "test --config=cuda" + if [ "$TF_CUDA_CLANG" == "1" ]; then + write_to_bazelrc "build --config=cuda_clang" + write_to_bazelrc "test --config=cuda_clang" + else + write_to_bazelrc "build --config=cuda" + write_to_bazelrc "test --config=cuda" + fi fi # end of if "$TF_NEED_CUDA" == "1" @@ -836,17 +859,17 @@ while true; do if [ -e "$MPI_HOME/include" ] && [ -e "$MPI_HOME/lib" ]; then break fi - + echo "Invalid path to the MPI Toolkit. ${MPI_HOME}/include or ${MPI_HOME}/lib cannot be found." if [ -z "$fromuser" ]; then exit 1 fi # Retry - MPI_HOME="" + MPI_HOME="" done - - + + if [ "$TF_NEED_MPI" == "1" ]; then write_to_bazelrc 'build --define with_mpi_support=true' @@ -854,11 +877,11 @@ if [ "$TF_NEED_MPI" == "1" ]; then ln -sf "${MPI_HOME}/include/mpi.h" third_party/mpi/mpi.h - #Determine if we use OpenMPI or MVAPICH, these require different header files + #Determine if we use OpenMPI or MVAPICH, these require different header files #to be included here to make bazel dependency checker happy if [ -e "${MPI_HOME}/include/mpi_portable_platform.h" ]; then - #OpenMPI + #OpenMPI ln -sf "${MPI_HOME}/include/mpi_portable_platform.h" third_party/mpi/ sed -i -e "s/MPI_LIB_IS_OPENMPI=False/MPI_LIB_IS_OPENMPI=True/" third_party/mpi/mpi.bzl else @@ -868,7 +891,7 @@ if [ "$TF_NEED_MPI" == "1" ]; then sed -i -e "s/MPI_LIB_IS_OPENMPI=True/MPI_LIB_IS_OPENMPI=False/" third_party/mpi/mpi.bzl fi - + if [ -e "${MPI_HOME}/lib/libmpi.so" ]; then ln -sf "${MPI_HOME}/lib/libmpi.so" third_party/mpi/ else diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 6450b2ad878b57191ae3b12e7e39213ac168eef6..2f46c2916ad4fe417def6fba1ce1e966f99d5f08 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -39,7 +39,7 @@ config_setting( config_setting( name = "android_armeabi", values = { - "cc_target_os": "android", + "crosstool_top": "//external:android/crosstool", "cpu": "armeabi", }, visibility = ["//visibility:public"], @@ -178,7 +178,10 @@ 
config_setting( package_group( name = "internal", - packages = ["//tensorflow/..."], + packages = [ + "//learning/protonn/llgtm/...", + "//tensorflow/...", + ], ) filegroup( @@ -216,9 +219,12 @@ filegroup( "//tensorflow/compiler/jit/kernels:all_files", "//tensorflow/compiler/jit/legacy_flags:all_files", "//tensorflow/compiler/jit/ops:all_files", + "//tensorflow/compiler/plugin/executor:all_files", "//tensorflow/compiler/tests:all_files", "//tensorflow/compiler/tf2xla:all_files", + "//tensorflow/compiler/tf2xla/cc:all_files", "//tensorflow/compiler/tf2xla/kernels:all_files", + "//tensorflow/compiler/tf2xla/ops:all_files", "//tensorflow/compiler/xla:all_files", "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", @@ -253,7 +259,7 @@ filegroup( "//tensorflow/contrib/data/python/kernel_tests:all_files", "//tensorflow/contrib/data/python/ops:all_files", "//tensorflow/contrib/data/python/util:all_files", - "//tensorflow/contrib/decision_trees:all_files", + "//tensorflow/contrib/decision_trees/proto:all_files", "//tensorflow/contrib/distributions:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", @@ -284,6 +290,9 @@ filegroup( "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", + "//tensorflow/contrib/predictor:all_files", + "//tensorflow/contrib/remote_fused_graph/pylib:all_files", + "//tensorflow/contrib/resampler:all_files", "//tensorflow/contrib/rnn:all_files", "//tensorflow/contrib/saved_model:all_files", "//tensorflow/contrib/saved_model/cc/saved_model:all_files", @@ -302,10 +311,17 @@ filegroup( "//tensorflow/contrib/stateless:all_files", "//tensorflow/contrib/tensor_forest:all_files", "//tensorflow/contrib/tensor_forest/hybrid:all_files", + "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", + "//tensorflow/contrib/tensor_forest/proto:all_files", 
"//tensorflow/contrib/tensorboard:all_files", "//tensorflow/contrib/testing:all_files", "//tensorflow/contrib/text:all_files", - "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files", + "//tensorflow/contrib/tfprof:all_files", + "//tensorflow/contrib/timeseries:all_files", + "//tensorflow/contrib/timeseries/examples:all_files", + "//tensorflow/contrib/timeseries/python/timeseries:all_files", + "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files", + "//tensorflow/contrib/tpu:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", @@ -327,6 +343,9 @@ filegroup( "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", "//tensorflow/core/platform/hadoop:all_files", + "//tensorflow/core/profiler:all_files", + "//tensorflow/core/profiler/internal:all_files", + "//tensorflow/core/profiler/internal/advisor:all_files", "//tensorflow/core/util/ctc:all_files", "//tensorflow/core/util/tensor_bundle:all_files", "//tensorflow/examples/android:all_files", @@ -351,72 +370,10 @@ filegroup( "//tensorflow/python/kernel_tests:all_files", "//tensorflow/python/kernel_tests/distributions:all_files", "//tensorflow/python/ops/distributions:all_files", + "//tensorflow/python/profiler:all_files", + "//tensorflow/python/profiler/internal:all_files", "//tensorflow/python/saved_model:all_files", "//tensorflow/python/tools:all_files", - "//tensorflow/tensorboard:all_files", - "//tensorflow/tensorboard/backend:all_files", - "//tensorflow/tensorboard/backend/event_processing:all_files", - "//tensorflow/tensorboard/components:all_files", - "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_audio_dashboard/test:all_files", - "//tensorflow/tensorboard/components/tf_backend:all_files", - "//tensorflow/tensorboard/components/tf_backend/test:all_files", - 
"//tensorflow/tensorboard/components/tf_color_scale:all_files", - "//tensorflow/tensorboard/components/tf_color_scale/test:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common/test:all_files", - "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_globals:all_files", - "//tensorflow/tensorboard/components/tf_graph:all_files", - "//tensorflow/tensorboard/components/tf_graph/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_app:all_files", - "//tensorflow/tensorboard/components/tf_graph_app/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_board:all_files", - "//tensorflow/tensorboard/components/tf_graph_board/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_common:all_files", - "//tensorflow/tensorboard/components/tf_graph_controls:all_files", - "//tensorflow/tensorboard/components/tf_graph_controls/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_graph_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card:all_files", - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_info:all_files", - "//tensorflow/tensorboard/components/tf_graph_info/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_loader:all_files", - "//tensorflow/tensorboard/components/tf_graph_loader/demo:all_files", - "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_image_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_imports:all_files", - "//tensorflow/tensorboard/components/tf_option_selector:all_files", - "//tensorflow/tensorboard/components/tf_profile_dashboard:all_files", - 
"//tensorflow/tensorboard/components/tf_profile_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_runs_selector:all_files", - "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_storage:all_files", - "//tensorflow/tensorboard/components/tf_storage/test:all_files", - "//tensorflow/tensorboard/components/tf_tensorboard:all_files", - "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_trace_viewer:all_files", - "//tensorflow/tensorboard/components/vz_distribution_chart:all_files", - "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files", - "//tensorflow/tensorboard/components/vz_line_chart:all_files", - "//tensorflow/tensorboard/components/vz_projector:all_files", - "//tensorflow/tensorboard/components/vz_projector/test:all_files", - "//tensorflow/tensorboard/components/vz_sorting:all_files", - "//tensorflow/tensorboard/components/vz_sorting/test:all_files", - "//tensorflow/tensorboard/demo:all_files", - "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", - "//tensorflow/tensorboard/plugins:all_files", - "//tensorflow/tensorboard/plugins/audio:all_files", - "//tensorflow/tensorboard/plugins/distributions:all_files", - "//tensorflow/tensorboard/plugins/graphs:all_files", - "//tensorflow/tensorboard/plugins/histograms:all_files", - "//tensorflow/tensorboard/plugins/images:all_files", - "//tensorflow/tensorboard/plugins/projector:all_files", - "//tensorflow/tensorboard/plugins/scalars:all_files", - "//tensorflow/tensorboard/plugins/text:all_files", - "//tensorflow/tensorboard/scripts:all_files", "//tensorflow/tools/api/golden:all_files", "//tensorflow/tools/api/lib:all_files", "//tensorflow/tools/api/tests:all_files", @@ -427,12 +384,10 @@ filegroup( "//tensorflow/tools/docker/notebooks:all_files", 
"//tensorflow/tools/docs:all_files", "//tensorflow/tools/git:all_files", + "//tensorflow/tools/mlpbtxt:all_files", "//tensorflow/tools/proto_text:all_files", "//tensorflow/tools/quantization:all_files", "//tensorflow/tools/test:all_files", - "//tensorflow/tools/tfprof:all_files", - "//tensorflow/tools/tfprof/internal:all_files", - "//tensorflow/tools/tfprof/internal/advisor:all_files", "//tensorflow/user_ops:all_files", "//third_party/hadoop:all_files", "//third_party/sycl:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 3ab4e8efcdb5b05cf8922edd302e7cbf3a3597f1..9267ef77efb9ba0ebc16515644b6febc65fecafc 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -62,6 +62,7 @@ tf_cuda_library( "//tensorflow/cc:scope_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", ], }), diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 77faa475ed47990a4dcee0e1ca0925af0c1643f9..f620e248c13a422df79dbbd91495785d1fe5def7 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -28,12 +28,15 @@ limitations under the License. 
#endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" @@ -163,7 +166,7 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, if (out->data != nullptr) { return InvalidArgument("Passing non-empty TF_Buffer is invalid."); } - const auto proto_size = in.ByteSize(); + const auto proto_size = in.ByteSizeLong(); void* buf = tensorflow::port::Malloc(proto_size); in.SerializeToArray(buf, proto_size); out->data = buf; @@ -466,15 +469,6 @@ TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) { dimvec.size(), base, size, DeleteArray, base); } -class TensorCApi { - public: - static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } - static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, - TensorBuffer* buf) { - return Tensor(static_cast<DataType>(type), shape, buf); - } -}; - // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to // result in a zero-sized tensor. 
static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { @@ -628,7 +622,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, // Target nodes const char** c_target_oper_names, int ntargets, const char** handle, TF_Status* status) { - status->status = Status::OK(); + *handle = nullptr; std::vector input_names(ninputs); std::vector output_names(noutputs); @@ -643,16 +637,12 @@ void TF_PRunSetup(TF_DeprecatedSession* s, target_oper_names[i] = c_target_oper_names[i]; } tensorflow::string new_handle; - Status result; - result = s->session->PRunSetup(input_names, output_names, target_oper_names, - &new_handle); - if (result.ok()) { + status->status = s->session->PRunSetup(input_names, output_names, + target_oper_names, &new_handle); + if (status->status.ok()) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; - } else { - *handle = nullptr; - status->status = result; } } @@ -1600,6 +1590,14 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, // TF_Graph functions --------------------------------------------------------- +TF_Graph::TF_Graph() + : graph(tensorflow::OpRegistry::Global()), + refiner(graph.versions().producer(), graph.op_registry()), + num_sessions(0), + delete_requested(false), + parent(nullptr), + parent_inputs(nullptr) {} + TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { @@ -2326,6 +2324,8 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, int ninputs, const TF_Output* outputs, int noutputs, const TF_Operation* const* target_opers, int ntargets, const char** handle, TF_Status* status) { + *handle = nullptr; + if (!ExtendSessionGraphHelper(session, status)) { return; } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 15139a47acf4b5bcf7a6b6fd77de5834f3f9189c..3aeafb46855a24e1987e2ca4f9f40cb7930794c1 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1101,8 +1101,7 @@ 
TF_CAPI_EXPORT extern void TF_SessionRun( // needed. // // On failure, out_status contains a tensorflow::Status with an error -// message. -// NOTE: This is EXPERIMENTAL and subject to change. +// message. *handle is set to nullptr. TF_CAPI_EXPORT extern void TF_SessionPRunSetup( TF_Session*, // Input names @@ -1118,7 +1117,6 @@ TF_CAPI_EXPORT extern void TF_SessionPRunSetup( // Continue to run the graph with additional feeds and fetches. The // execution state is uniquely identified by the handle. -// NOTE: This is EXPERIMENTAL and subject to change. TF_CAPI_EXPORT extern void TF_SessionPRun( TF_Session*, const char* handle, // Input tensors diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index f17ac26ad9665d7ea8cc1ef566cad81bba712b62..7e987a65f7b24eab6508c095e6177414503d43b9 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -56,13 +56,8 @@ struct TF_Library { }; struct TF_Graph { - TF_Graph() - : graph(tensorflow::OpRegistry::Global()), - refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), - delete_requested(false), - parent(nullptr), - parent_inputs(nullptr) {} + TF_Graph(); + tensorflow::mutex mu; tensorflow::Graph graph GUARDED_BY(mu); @@ -117,3 +112,16 @@ struct TF_ImportGraphDefOptions { struct TF_DeviceList { std::vector<tensorflow::DeviceAttributes> response; }; + +namespace tensorflow { + +class TensorCApi { + public: + static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } + static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, + TensorBuffer* buf) { + return Tensor(static_cast<DataType>(type), shape, buf); + } +}; + +} // end namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 04540bd793dab34c2f707e9e995defe7b4e15858..736f56837c4775e3b6cf03f8bbee88e91c2d69fa 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/tensor_id.h" @@ -911,11 +912,13 @@ class CSession { for (TF_Operation* o : outputs) { outputs_.emplace_back(TF_Output{o, 0}); } + output_values_.resize(outputs_.size()); } void SetOutputs(const std::vector& outputs) { ResetOutputValues(); outputs_ = outputs; + output_values_.resize(outputs_.size()); } void SetTargets(std::initializer_list targets) { diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index f89cc6384b3440a1e4d1bfe596b145eef5604964..b461a475c13df2385cff05ffd273f3ee23bd20c0 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -45,6 +45,7 @@ tf_cc_test( "//tensorflow/core:all_kernels", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -61,7 +62,6 @@ cc_library( ":gradients", ":ops", ":scope", - "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -274,10 +274,6 @@ cc_library( deps = [ ":cc_ops", ":grad_op_registry", - ":ops", - ":scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", ], ) @@ -305,10 +301,6 @@ cc_library( ":cc_ops", ":cc_ops_internal", ":grad_op_registry", - ":ops", - ":scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", ], ) @@ -441,6 +433,7 @@ cc_library_with_android_deps( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", + "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", ], @@ -527,7 +520,6 @@ cc_library( deps = [ ":coordinator", 
"//tensorflow/core:core_cpu", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -560,8 +552,6 @@ cc_library( srcs = ["training/coordinator.cc"], hdrs = ["training/coordinator.h"], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 71aa986f918de68822d457422f6c7a73d6253819..80dd272f6f9dd5eecf5d7002bdf1c7c98e4c3ba3 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -18,8 +18,12 @@ limitations under the License. #include #include "tensorflow/cc/framework/cc_op_gen.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 8c00a6f70497df2c70f266a747197e50c98375bb..29ad8a934bddd38f4ddfa2a2d3ed8a598a01da20 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -152,12 +152,12 @@ Status SymbolicGradientBuilder::Initialize() { grad_outputs_->resize(inputs_.size()); // Populate `output_nodes_` from node ids in `outputs_`. output_nodes_.reserve(outputs_.size()); - for (int i = 0; i < outputs_.size(); ++i) { + for (size_t i = 0; i < outputs_.size(); ++i) { output_nodes_.insert(outputs_[i].node()->id()); } // Populate `input_nodes_` from Outputs in `inputs_`. 
input_nodes_.reserve(inputs_.size()); - for (int i = 0; i < inputs_.size(); ++i) { + for (size_t i = 0; i < inputs_.size(); ++i) { input_nodes_.insert({inputs_[i], i}); } @@ -341,7 +341,7 @@ Status SymbolicGradientBuilder::AddGradients() { // gradient function to the src node/output to which it should be // backproped. Maybe grad functions can return a vector of Output pairs to // make this association explicit. - int dx_index = 0; + size_t dx_index = 0; for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) continue; if (dx_index == dx.size()) { diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 6a249825812b4d39b55f7170a35436b6ae88c020..2aad9784808ea1ca7fa30434c22d1ee92ad7857c 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 32c0822de69da7989ceaa4028539db928b6fcea3..1948dd4e46b932775fdb5cbbdad7b66338b0fcf4 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -136,7 +136,7 @@ Scope::Impl::Impl(const std::shared_ptr& graph, Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = - new ShapeRefiner(graph->versions().producer(), graph->op_registry()); + new ShapeRefiner(graph->versions(), graph->op_registry()); return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner)); } diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 
37f07e71a0dff9144f193679bbcfcf581c1538cf..48185db3cbdd065d76e8bf75ef926938bd3a1268 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -100,6 +100,17 @@ Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); +Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + grad_outputs->push_back(Identity(scope, grad_inputs[0])); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad); + Status SplitGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -247,6 +258,17 @@ Status ScatterNdGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad); +Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto indices = op.input(1); + grad_outputs->push_back(Identity(scope, grad_inputs[0])); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices)); + return scope.status(); +} +REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad); + Status PadGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 5798b5b509fc14e6c9d95d4fd42aca893254f775..1777e181451b267f52a418888912ed1393bdf8b1 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -233,6 +233,28 @@ TEST_F(ArrayGradTest, ScatterNdGrad_SliceIndexing) { RunTest(updates, 
updates_shape, y, y_shape); } +TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SimpleIndexing) { + TensorShape updates_shape({4}); + TensorShape input_shape({8}); + auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape)); + auto updates = + Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); + auto indices = Const(scope_, {{4}, {3}, {1}, {7}}); + auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates); + RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); +} + +TEST_F(ArrayGradTest, ScatterNdNonAliasingAddGrad_SliceIndexing) { + TensorShape updates_shape({2, 4, 4}); + TensorShape input_shape({4, 4, 4}); + auto input = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(input_shape)); + auto updates = + Placeholder(scope_, DT_FLOAT, Placeholder::Shape(updates_shape)); + auto indices = Const(scope_, {{0}, {2}}); + auto y = ScatterNdNonAliasingAdd(scope_, input, indices, updates); + RunTest({input, updates}, {input_shape, updates_shape}, {y}, {input_shape}); +} + TEST_F(ArrayGradTest, PadGrad) { TensorShape x_shape({2, 3}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 71d9a8ed7be5ea75a3b26224df871b955f05c132..0b9b665b1eb4420827b152a88d9023ceab4d932d 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -203,6 +203,46 @@ Status TanhGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Tanh", TanhGrad); +Status AsinhGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = asinh(x) + // dy/dx = 1 / cosh(y) + auto dydx = Reciprocal(scope, Cosh(scope, op.output(0))); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Asinh", AsinhGrad); + +Status 
AcoshGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = acosh(x) + // dy/dx = 1 / sinh(y) + auto dydx = Reciprocal(scope, Sinh(scope, op.output(0))); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Acosh", AcoshGrad); + +Status AtanhGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = atanh(x) + // dy/dx = 1 / (1 - x^2) + auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0)))); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Atanh", AtanhGrad); + Status SigmoidGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1653b04378f30bd788d549da04d4140ac7d6317e..48b3ddbe90c2313ec0aa50729f277a1c258de52c 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -48,6 +48,9 @@ class CWiseUnaryGradTest : public ::testing::Test { SINH, COSH, TANH, + ASINH, + ACOSH, + ATANH, SIGMOID, SIGN, SIN, @@ -122,6 +125,15 @@ class CWiseUnaryGradTest : public ::testing::Test { case TANH: y = Tanh(scope_, x); break; + case ASINH: + y = Asinh(scope_, x); + break; + case ACOSH: + y = Acosh(scope_, x); + break; + case ATANH: + y = Atanh(scope_, x); + break; case SIGMOID: y = Sigmoid(scope_, x); break; @@ -413,6 +425,76 @@ TEST_F(CWiseUnaryGradTest, Tanh_Complex) { TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); } +TEST_F(CWiseUnaryGradTest, Asinh) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto 
dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + auto y = std::asinh(x); + return dy / std::cosh(y); + }; + TestCWiseGrad(ASINH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Asinh_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + auto y = std::asinh(x); + return dy / conjugate(std::cosh(y)); + }; + TestCWiseGrad(ASINH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Acosh) { + auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7}); }; + auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13, 14}); }; + auto dx_fn = [this](const float x, const float dy) { + auto y = std::acosh(x); + return dy / std::sinh(y); + }; + TestCWiseGrad(ACOSH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Acosh_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 1}, {2, 1}, {1, 4}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{2, 2}, {3, 3}, {1, 4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + auto y = std::acosh(x); + return dy / conjugate(std::sinh(y)); + }; + TestCWiseGrad(ACOSH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Atanh) { + auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * (1. / (1. 
- x * x)); + }; + TestCWiseGrad(ATANH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Atanh_Complex) { + auto x_fn = [this](const int i) { + return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / conjugate(one_ - x * x); + }; + TestCWiseGrad(ATANH, x_fn, dy_fn, dx_fn); +} + TEST_F(CWiseUnaryGradTest, Sigmoid) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 5e5203d09055d65cb1dcc16e091f6e5028ee7ae1..952b2015edf5dee1246ac2ccb8722ced41c22dfa 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -46,6 +46,19 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); +Status LogSoftmaxGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + + auto softmax = Exp(scope, op.output(0)); + auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true)); + auto mul = Mul(scope, sum, softmax); + auto dx = Sub(scope, grad_inputs[0], mul); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LogSoftmax", LogSoftmaxGrad); + Status ReluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 70c9bd4e08b2b46866a44becc8fe1305fec48ea9..daa87546ec08474eedb640d0a31aed4d172c9a24 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -57,6 +57,19 @@ TEST_F(NNGradTest, SoftmaxGrad) { RunTest(x, shape, y, shape); } 
+TEST_F(NNGradTest, LogSoftmaxGrad) { + TensorShape shape({5, 3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = LogSoftmax(scope_, x); + // Avoid numerical instability when computing finite differences. + Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, + 0.1f, 0.3f, 0.5f, 0.7f, 0.8f, + -0.1f, 0.1f, 0.1f, 0.1f, 1.2f}, + {5, 3}); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, ReluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 94a3b3cf465a279e3bb44344739499ad670119c3..c940df8a8761d97a859be3af30980ff79ca3577a 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -21,6 +21,9 @@ namespace tensorflow { /// SavedModel assets directory. constexpr char kSavedModelAssetsDirectory[] = "assets"; +/// SavedModel assets.extra directory. +constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra"; + /// SavedModel assets key for graph collection-def. constexpr char kSavedModelAssetsKey[] = "saved_model_assets"; diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 807f5904afcf36890f4bd02f0d811a3ebe0cceba..f98abc8a817eca7bc129bb03a2ad31b97d957065 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/saved_model.pb.h" @@ -76,8 +77,16 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, return Status::OK(); } } + string tags_as_string = "{ "; + for (const string& tag : tags) { + tags_as_string = strings::StrCat(tags_as_string, tag, " "); + } + tags_as_string = strings::StrCat(tags_as_string, "}"); return Status(error::Code::NOT_FOUND, - "Could not find meta graph def matching supplied tags."); + "Could not find meta graph def matching supplied tags: " + + tags_as_string + + ". To inspect available tag-sets in the SavedModel, please " + "use the SavedModel CLI: `saved_model_cli`"); } Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index cef29e7b071e538a60193fd998acc0fb29c2cea3..0ad6b33bba5fcceaca68e2f179cef2232c689a80 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -133,9 +133,9 @@ TEST_F(LoaderTest, NoTagMatch) { Status st = LoadSavedModel(session_options, run_options, export_dir, {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE( - StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags.")) + EXPECT_TRUE(StringPiece(st.error_message()) + .contains("Could not find meta graph def matching supplied " + "tags: { missing-tag }")) << st.error_message(); } @@ -151,7 +151,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { EXPECT_FALSE(st.ok()); EXPECT_TRUE( StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags.")) + .contains("Could not find meta graph def matching 
supplied tags: ")) << st.error_message(); } diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h index 48ab1158e462af25c27a728e404a041516e82057..2b0b2d5c7fb33768494c1781669c1adcb875a579 100644 --- a/tensorflow/cc/saved_model/tag_constants.h +++ b/tensorflow/cc/saved_model/tag_constants.h @@ -18,10 +18,13 @@ limitations under the License. namespace tensorflow { +/// Tag for the `gpu` graph. +constexpr char kSavedModelTagGpu[] = "gpu"; + /// Tag for the `serving` graph. constexpr char kSavedModelTagServe[] = "serve"; -/// Tag for the `training` graph.` +/// Tag for the `training` graph. constexpr char kSavedModelTagTrain[] = "train"; } // namespace tensorflow diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 1f6fe28188cfbb6a64935e4a3f70cf8e0f6eb9ad..42d6de34c0e01e3f58210f8bbbcb8c860ae7bcd8 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -108,6 +108,7 @@ cc_test( deps = [ ":tfcompile_lib", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -126,14 +127,7 @@ cc_library( deps = [ ":tfcompile_lib", ":tfcompile_proto", - "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags", - "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags", - "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", - "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", "//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/compiler/xla/legacy_flags:util_flags", "//tensorflow/compiler/xla/service:compiler", @@ -161,6 +155,17 @@ tf_library( tags = ["manual"], ) +# A test of tf_library that includes a graph with an unknown 
op, but where +# the compilation works because the unknown op is not needed for the fetches. +tf_library( + name = "test_graph_tfunknownop", + testonly = 1, + config = "test_graph_tfunknownop.config.pbtxt", + cpp_class = "UnknownOpAddComp", + graph = "test_graph_tfunknownop.pbtxt", + tags = ["manual"], +) + # Utility library for benchmark binaries, used by the *_benchmark rules that are # added by the tfcompile bazel macro. cc_library( @@ -204,6 +209,7 @@ test_suite( tests = [ ":benchmark_test", ":test_graph_tfadd_test", + ":test_graph_tfunknownop_test", "//tensorflow/compiler/aot/tests:all_tests", ], ) diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index ca17c5ab690f606bd531638fece8b0a74cdd8c18..59ff14600bc70ab0e635ecf0e20b6395fbd8f5a2 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -378,9 +379,16 @@ Status CompileXla(xla::CompileOnlyClient* client, Status InitGraph(const GraphDef& graph_def, const Config& config, const MainFlags& flags, std::unique_ptr* graph) { TF_RETURN_IF_ERROR(ValidateConfig(config)); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library()); std::unique_ptr g(new Graph(flib_def)); - GraphDef copy_def(graph_def); + + GraphDef copy_def; + + // Prune the GraphDef first so that unknown ops that we aren't compiling get + // filtered out. 
+ TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, ©_def)); + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©_def, *g->op_registry(), 0 /*node_offset*/)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5625c0ab03893c997245a6449d145b9149b48627 --- /dev/null +++ b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_sum" } +} diff --git a/tensorflow/tensorboard/components/tf_graph_app/demo/data/graph.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt similarity index 53% rename from tensorflow/tensorboard/components/tf_graph_app/demo/data/graph.pbtxt rename to tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt index 8b95b258df4806dcf84e3b4c1c14cd0434df8910..212ffbb5ffc78ea690fe11a3a85f98fe97876e5d 100644 --- a/tensorflow/tensorboard/components/tf_graph_app/demo/data/graph.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfunknownop.pbtxt @@ -1,90 +1,86 @@ node { - name: "life" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } + name : "x_const" + op : "Const" attr { key: "value" value { tensor { dtype: DT_INT32 tensor_shape { + dim { + size: 1 + } } - int_val: 2 + int_val: 1 } } } -} -node { - name: "universe" - op: "Const" attr { - key: "dtype" + key : "dtype" value { - type: DT_INT32 + type : DT_INT32 } } +} +node { + name : "y_const" + op : "Const" attr { key: "value" value { tensor { dtype: DT_INT32 tensor_shape { + dim { + size: 1 + } } - int_val: 40 + int_val: 2 } } } -} -node { - name: "everything" - op: "Const" attr { key: "dtype" value { type: DT_INT32 } } +} +node { + name : "x_y_sum" + op : "Add" 
+ input : "x_const" + input : "y_const" attr { - key: "value" + key : "T" value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } + type: DT_INT32 } } } node { - name: "Add" - op: "Add" - input: "life" - input: "universe" + name : "z" + op : "SomeUnknownOp" + input : "x_const" attr { - key: "T" + key : "T" value { type: DT_INT32 } } } node { - name: "answer" - op: "Add" - input: "Add" - input: "everything" + name : "x_z_sum" + op : "Add" + input : "x_const" + input : "z" attr { - key: "T" + key : "T" value { type: DT_INT32 } } } versions { - producer: 10 + producer: 15 } diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 6fed46b4329606baeed21dd9ee4d34849a7c50a0..12825344d58d1a3778b84f9972fa2b571db225d0 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -23,14 +23,7 @@ limitations under the License. #include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/tfcompile.pb.h" #include "tensorflow/compiler/aot/tfcompile_util.h" -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" #include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/legacy_flags/util_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -133,17 +126,11 @@ int main(int argc, char** argv) { flags.target_triple = "x86_64-pc-linux"; flags.out_object = "out.o"; flags.out_header = "out.h"; + 
flags.entry_point = "entry"; std::vector flag_list; AppendMainFlags(&flag_list, &flags); - xla::legacy_flags::AppendAliasAnalysisFlags(&flag_list); - xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list); - xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); - xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); - xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendLlvmUtilFlags(&flag_list); xla::legacy_flags::AppendServiceFlags(&flag_list); xla::legacy_flags::AppendUtilFlags(&flag_list); diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/aot/tfcompile_util.cc index fd073a2e2623b4b24ddc58360525886f3fc1b3ac..8774a02128e8d8d6373576e958f70585857a3081 100644 --- a/tensorflow/compiler/aot/tfcompile_util.cc +++ b/tensorflow/compiler/aot/tfcompile_util.cc @@ -15,10 +15,14 @@ limitations under the License. #include "tensorflow/compiler/aot/tfcompile_util.h" +#include #include +#include #include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -115,5 +119,51 @@ Status ValidateConfig(const Config& config) { return Status::OK(); } +Status PruneGraphDefInto(const Config& config, const GraphDef& in, + GraphDef* out) { + *out = in; + out->clear_node(); + + // Maps node name to reachability. 
+ std::unordered_map> node_by_name; + for (const NodeDef& node : in.node()) { + node_by_name[node.name()] = std::pair(false, &node); + } + + std::queue name_queue; + for (int i = 0; i < config.fetch_size(); ++i) { + name_queue.push(config.fetch(i).id().node_name()); + } + while (!name_queue.empty()) { + const string name = name_queue.front(); + name_queue.pop(); + + auto find_it = node_by_name.find(name); + if (find_it == node_by_name.end()) { + return errors::InvalidArgument("While pruning graph, node ", name, + " needed but not found in the graph."); + } + auto& map_entry = find_it->second; + if (map_entry.first) { + continue; + } + map_entry.first = true; + + for (const string& in_edge : map_entry.second->input()) { + name_queue.push(ParseTensorName(in_edge).first.ToString()); + } + } + + // Copy over, preserving order of original and only nodes that are reachable + // from the fetches. + out->mutable_node()->Reserve(in.node_size()); + for (const NodeDef& node : in.node()) { + if (node_by_name[node.name()].first) { + *out->add_node() = node; + } + } + return Status::OK(); +} + } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/aot/tfcompile_util.h index 651d75d0d02bdac110159996498778d2c57ddf78..84060c0761e5bea3bd589bb27bbbbae3bb1bf659 100644 --- a/tensorflow/compiler/aot/tfcompile_util.h +++ b/tensorflow/compiler/aot/tfcompile_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ #include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,6 +31,11 @@ Status ValidateCppIdent(StringPiece ident, StringPiece msg); // ValidateConfig returns OK iff config is valid. Status ValidateConfig(const Config& config); +// Returns in a copy of , pruned to only include fetches from +// . 
+Status PruneGraphDefInto(const Config& config, const GraphDef& in, + GraphDef* out); + } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/aot/tfcompile_util_test.cc index c321d3ff4c779fbd2e9c67dfc1eb24c734a9103f..5a92851ceb972ca63a8a3845eb4730fe198762dd 100644 --- a/tensorflow/compiler/aot/tfcompile_util_test.cc +++ b/tensorflow/compiler/aot/tfcompile_util_test.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -180,6 +182,65 @@ TEST(ValidateConfig, ConflictingFetchName) { ExpectErrorContains(ValidateConfig(config), "conflicting fetch name"); } +static Config FetchesConfig(std::vector fetches) { + Config config; + for (const auto& fetch_node_name : fetches) { + auto* fetch = config.add_fetch(); + fetch->set_name(strings::StrCat("fetch_", fetch_node_name)); + fetch->mutable_id()->set_node_name(fetch_node_name); + } + return config; +} + +TEST(PruneGraphDefInto, Basic) { + GraphDef def; + auto* n = def.add_node(); + n->set_name("a"); + n->add_input("b:0"); + n->add_input("^c"); + + GraphDef copy; + ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"missing"}), def, ©), + "node missing needed"); + ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, ©), + "node b needed"); + + n = def.add_node(); + n->set_name("b"); + ExpectErrorContains(PruneGraphDefInto(FetchesConfig({"a"}), def, ©), + "node c needed"); + n->add_input("d:1"); + + n = def.add_node(); + n->set_name("c"); + n->add_input("d:1"); + + n = def.add_node(); + n->set_name("d"); + + // Graph is full, no pruning done. 
+ // Graph right now has diamond from d: + // d --> b --> a + // d --> c --> a + TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, ©)); + EXPECT_EQ(def.DebugString(), copy.DebugString()); + GraphDef pruned_a = copy; + + // Add some unrelated fields that use b and c, but are not needed for a. + n = def.add_node(); + n->set_name("e"); + n->add_input("^d"); + n->add_input("b:2"); + copy.Clear(); + TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a"}), def, ©)); + EXPECT_EQ(pruned_a.DebugString(), copy.DebugString()); + + // Fetch "a" and "e" to get the original graph. + copy.Clear(); + TF_EXPECT_OK(PruneGraphDefInto(FetchesConfig({"a", "e"}), def, ©)); + EXPECT_EQ(def.DebugString(), copy.DebugString()); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5f857191da78ddd68c5689f9c4f467c01300ca7c..8b2d0b7659a5e404311be55c2b3497e2bcc5fa15 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -15,27 +15,16 @@ package_group( ) package( - default_visibility = [":internal"], + default_visibility = [ + ":internal", + "//tensorflow/compiler/plugin/executor:__pkg__", + ], ) load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -# This target can be used by XLA device plugins to prevent circular -# dependencies, and provides access to all of the required headers -# for building a device library. -cc_header_only_library( - name = "xla_jit_headers_lib", - visibility = ["//visibility:public"], - deps = [ - ":xla_cpu_device", - ":xla_cpu_jit", - ":xla_gpu_device", - ":xla_gpu_jit", - ], -) - # Target that bundles up the XLA CPU and GPU JIT devices. 
cc_library( name = "jit", @@ -150,6 +139,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:tensorflow_opensource", "//tensorflow/core/kernels:constant_op", @@ -283,3 +273,15 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. +cc_header_only_library( + name = "xla_jit_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], +) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 14d8f2ab351bd99dd3fe42a9ac6e31062d552ff0..a1ddad3e9b8191ee4d783136d2b509ec15d993d1 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index c4116cb8b52adc191e9f695bc9a6e0cf413b4b5c..97f3512a6c43cfe179c8801e0872cbf77ef5fde5 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -2,6 +2,7 @@ licenses(["notice"]) # Apache 2.0 package( default_visibility = [ + "//tensorflow/compiler/plugin/executor:__pkg__", "//tensorflow/compiler/tf2xla:internal", ], ) @@ -38,6 +39,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc index c86e03118b53ddf4865b7995b1d197c3ef07ba29..bd4eefbc0bb960f8ddc1d238057e73a29a098f26 100644 --- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc +++ b/tensorflow/compiler/jit/kernels/parallel_check_op.cc @@ -64,7 +64,7 @@ class ParallelCheckOp : public OpKernel { ok = (diff <= tolerance); } if (ok) continue; - LOG(ERROR) << "Op " << def().name() << " fails equality at output " + LOG(ERROR) << "Op " << name() << " fails equality at output " << input_idx << " type " << DataTypeString(dtype) << " element " << i << ": std_val=" << p0[i] << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); @@ -75,7 +75,7 @@ class ParallelCheckOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - 
VLOG(1) << "Compute " << def().name(); + VLOG(1) << "Compute " << name(); const int num_pairs = ctx->num_inputs() / 2; for (int i = 0; i < num_pairs; ++i) { CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); @@ -113,7 +113,7 @@ class ParallelCheckOp : public OpKernel { LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); } if (failed > 0) { - LOG(ERROR) << "check failed for " << def().name() << " output " << i + LOG(ERROR) << "check failed for " << name() << " output " << i << " num_elts: " << num_elts; legacy_flags::ParallelCheckOpFlags* flags = legacy_flags::GetParallelCheckOpFlags(); @@ -121,7 +121,7 @@ class ParallelCheckOp : public OpKernel { LOG(QFATAL) << "failfast on first parallel-check failure"; } } else { - VLOG(1) << "check passed for " << def().name() << " output " << i + VLOG(1) << "check passed for " << name() << " output " << i << " num_elts: " << num_elts; } diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc index 29c5ff724299ec84d31268c4227259ec02d10742..bd051d06ae961ea86fd864c1b051a30572779e64 100644 --- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/allocator.h" @@ -149,6 +150,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { xla::ExecutionOptions execution_options; *execution_options.mutable_shape_with_output_layout() = kernel->xla_output_shape; + *execution_options.mutable_debug_options() = + xla::legacy_flags::GetDebugOptionsFromFlags(); Env* env = Env::Default(); auto start_time = env->NowMicros(); VLOG(1) << "Executing XLA Computation..."; @@ -202,8 +205,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { // Apply variable updates, if any. VLOG(2) << "Applying variable updates"; - for (int i = 0; i < kernel->variable_updates.size(); ++i) { - const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i]; + for (int i = 0; i < kernel->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; OP_REQUIRES(ctx, write.input_index >= 0 && write.input_index < ctx->num_inputs(), errors::Internal("Invalid input index for variable write.")); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index f1fef85f994a5f1f7514a5cb8b8b339706c7d998..7eab7bb28f0e2260b68481ee92ae56a9b55f86b0 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" @@ -162,10 +163,12 @@ Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { return Status::OK(); } -// Does `node` have a DT_RESOURCE typed argument? -bool HasResourceArgument(const Node& node) { +// Tests whether `node` has a DT_RESOURCE typed input or output. +bool HasResourceInputOrOutput(const Node& node) { return std::find(node.input_types().begin(), node.input_types().end(), - DT_RESOURCE) != node.input_types().end(); + DT_RESOURCE) != node.input_types().end() || + std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); } Status FindCompilationCandidates( @@ -193,9 +196,10 @@ Status FindCompilationCandidates( << ": " << node->type_string(); continue; } - if (!registration->compile_resource_ops && HasResourceArgument(*node)) { - VLOG(2) << "Compilation rejected node: resource argument " << node->name() - << ": " << node->type_string(); + if (!registration->compile_resource_ops && + HasResourceInputOrOutput(*node)) { + VLOG(2) << "Compilation rejected node: resource input/output " + << node->name() << ": " << node->type_string(); continue; } if (node->type_string() == "While" && diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 9f30e12e0e30fef6b4bcd0ea3c091842b008c29a..4b88da27a188ed4fa6125b3e7a84034efb1a0ec1 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -14,11 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" -#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -455,5 +457,39 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { EXPECT_EQ(clusters["B"], clusters["C"]); } +REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float"); +REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource"); + +namespace { + +class DummyOp : public XlaOpKernel { + using XlaOpKernel::XlaOpKernel; + void Compile(XlaOpKernelContext* ctx) override {} +}; + +REGISTER_XLA_OP(Name("ResourceInput"), DummyOp); +REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp); + +} // namespace + +TEST(XlaCompilationTest, Resources) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + // We should not form clusters with resource ops by default. + Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); + Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); + ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + TF_EXPECT_OK(builder.ToGraph(graph.get())); + } + MarkForCompilation(&graph); + auto clusters = GetClusters(*graph); + EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. 
+} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 8d1fa03cc0d74f3a61b3e2e1d6f2af07c0bcd23f..e5787ca4c8cff436e4404b8488970248b24a5eda 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -1,32 +1,20 @@ licenses(["notice"]) # Apache 2.0 package( - default_visibility = [ - "//tensorflow/compiler/tf2xla:internal", - ], + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) cc_library( name = "xla_ops", - srcs = [ - "xla_ops.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + srcs = ["xla_ops.cc"], + deps = ["//tensorflow/core:framework"], alwayslink = 1, ) cc_library( name = "parallel_check_op", srcs = ["parallel_check_op.cc"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + deps = ["//tensorflow/core:framework"], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 63ca77f9a912acce2078f3da43d64f2e10049380..2325217b973a29285362953cbdf9cc01437ce3f6 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -182,17 +182,18 @@ Status BuildArguments(int num_constant_args, XlaCompiler::Argument& arg = (*args)[input_num]; arg.name = variable_args[variable_id].name; + arg.kind = XlaCompiler::Argument::kVariable; if (variable_args[variable_id].present) { const Tensor& value = variable_args[variable_id].value; - arg.kind = XlaCompiler::Argument::kVariable; arg.type = value.dtype(); arg.shape = value.shape(); + arg.initialized = true; } else { // The values of uninitialized variables are not passed as inputs, since // they are meaningless. 
However, it is legal to assign to a resource // variable for the first time inside the XLA computation, so we do permit // uninitialized variables. - arg.kind = XlaCompiler::Argument::kUninitializedVariable; + arg.initialized = false; arg.type = DT_INVALID; arg.shape = TensorShape(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 5e336c5287bd9e2067e93cd8db8a5a1b62b62bd2..615e2230f42f63f893ad645e1ab9513d6c30abf5 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -31,9 +31,11 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/notification.h" diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index f329e83e14dfce68eff3feb720c1603bd36fa7d6..0ab81ebd5ffec0b3dd6aee509a6d4d2b41d156db 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -137,7 +137,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(result.status()); return; } - const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie()); + const void* src_ptr = result.ValueOrDie()->InternalData(); void* dst_ptr = DMAHelper::base(cpu_tensor); size_t total_bytes = cpu_tensor->TotalBytes(); memcpy(dst_ptr, src_ptr, total_bytes); diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD index 
9bc706abdf646a32da734906cada727d949eee21..2e5875705f2657c30be61fa1781422b8c7325765 100644 --- a/tensorflow/compiler/plugin/executor/BUILD +++ b/tensorflow/compiler/plugin/executor/BUILD @@ -11,9 +11,11 @@ cc_library( "*.h", ]), deps = [ + "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_jit_headers_lib", + "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:xla_headers_lib", - "//tensorflow/compiler/xla/service:hlo_evaluator", + "//tensorflow/compiler/xla/service", "//third_party/eigen3", "@local_config_cuda//cuda:cuda_headers", "@protobuf//:protobuf_headers", diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc index 893ff152f0c77c354be178818eaf9e8fc75feaa4..72fe7ba4519833e17314f8fef803ad0230713780 100644 --- a/tensorflow/compiler/plugin/executor/compiler.cc +++ b/tensorflow/compiler/plugin/executor/compiler.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/plugin/executor/compiler.h" #include "tensorflow/compiler/plugin/executor/executable.h" - #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -30,27 +29,23 @@ limitations under the License. #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/status_macros.h" - +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/strcat.h" -#include "tensorflow/core/lib/core/errors.h" +namespace xla { +namespace executorplugin { namespace se = ::perftools::gputools; namespace sep = ::perftools::gputools::executorplugin; -namespace port = ::perftools::gputools::port; - -namespace xla { -namespace executorplugin { /* * Run optimization passes on the module. 
The graph is transformed by * each pass in the optimization pipeline. The service subdirectory * contains useful optimization passes. */ -Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module, - HloDumper dump_hlo) { - HloPassPipeline pipeline("Executor", dump_hlo); +Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) { + HloPassPipeline pipeline("Executor"); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(false); @@ -67,13 +62,13 @@ Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module, } StatusOr> ExecutorCompiler::Compile( - std::unique_ptr hlo_module, HloDumper dump_hlo, + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); VLOG(1) << "Generate graph " << hlo_module->name(); - TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get(), dump_hlo)); + TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); // Typically you would visit the HLO graph, building up a compiled equivalent // In this case we are using an Hlo evaluator at execution time, so we don't @@ -88,7 +83,7 @@ StatusOr> ExecutorCompiler::Compile( StatusOr>> ExecutorCompiler::Compile( std::vector> hlo_modules, - HloDumper dump_hlos, std::vector stream_execs) { + std::vector stream_execs) { return tensorflow::errors::Unimplemented( "Compilation of multiple HLO modules is not supported on Executor."); @@ -97,7 +92,7 @@ StatusOr>> ExecutorCompiler::Compile( StatusOr>> ExecutorCompiler::CompileAheadOfTime( std::vector> hlo_modules, - HloDumper dump_hlo, const AotCompilationOptions& aot_options) { + const AotCompilationOptions& aot_options) { return tensorflow::errors::InvalidArgument( "AOT compilation not supported on Executor"); @@ -112,12 +107,11 @@ ExecutorCompiler::ShapeSizeBytesFunction() const { return ExecutorExecutable::ShapeSizeBytes; } - -} // namespace executorplugin -} // namespace xla - REGISTER_MODULE_INITIALIZER(executor_compiler, { 
xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() { return xla::MakeUnique(); }); }); + +} // namespace executorplugin +} // namespace xla diff --git a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/plugin/executor/compiler.h index 8fe591c8abd57933aafa6c82159b49aad45a42d5..d318eefc49f0f1983cf58802d56e71b799944b11 100644 --- a/tensorflow/compiler/plugin/executor/compiler.h +++ b/tensorflow/compiler/plugin/executor/compiler.h @@ -35,25 +35,23 @@ class ExecutorCompiler : public Compiler { StatusOr> Compile( std::unique_ptr hlo_module, - HloDumper dump_hlo, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( std::vector> hlo_module, - HloDumper dump_hlo, std::vector stream_exec) override; StatusOr>> CompileAheadOfTime( std::vector> module, - HloDumper dump_hlo, const AotCompilationOptions& options) override; + const AotCompilationOptions& options) override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; perftools::gputools::Platform::Id PlatformId() const override; private: - Status RunHloOptimization(HloModule* hlo_module, HloDumper dump_hlo); + Status RunHloOptimization(HloModule* hlo_module); TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler); }; diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc index 92a517ba533cb073dac9b37179825d089e29f3ab..4673a90e0a9251bd2f62df866bd3dadcd1cef756 100644 --- a/tensorflow/compiler/plugin/executor/executable.cc +++ b/tensorflow/compiler/plugin/executor/executable.cc @@ -15,18 +15,16 @@ limitations under the License. 
#include "tensorflow/compiler/plugin/executor/executable.h" #include "tensorflow/compiler/plugin/executor/executor.h" - -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" - #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/shape_util.h" -namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::executorplugin; - namespace xla { namespace executorplugin { +namespace se = ::perftools::gputools; +namespace sep = ::perftools::gputools::executorplugin; + ExecutorExecutable::ExecutorExecutable(std::unique_ptr hlo_module) : Executable(std::move(hlo_module), ShapeSizeBytes) {} @@ -36,7 +34,7 @@ static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor const Literal& literal) { int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); void* buf = executor->Allocate(size); - const void* src = LiteralUtil::InternalData(literal); + const void* src = literal.InternalData(); memcpy(buf, src, size); return se::DeviceMemoryBase(buf, size); } @@ -86,19 +84,18 @@ StatusOr ExecutorExecutable::ExecuteOnStream( for (int64 p = 0; p < computation->num_parameters(); p++) { // Create the input literal for the parameter HloInstruction* param = computation->parameter_instruction(p); - arg_literals.emplace_back(LiteralUtil::CreateFromShape(param->shape())); + arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); arg_literals_ptrs.push_back(arg_literals.back().get()); // Copy in the data from the stream_executor buffers - void* buffer = LiteralUtil::MutableInternalData(arg_literals.back().get()); + void* buffer = arg_literals.back()->MutableInternalData(); memcpy(buffer, arguments[p].opaque(), ShapeUtil::ByteSizeOf(param->shape())); } // Execute the graph using the evaluator HloEvaluator evaluator; - std::unique_ptr output; - TF_ASSIGN_OR_RETURN(output, + TF_ASSIGN_OR_RETURN(std::unique_ptr output, evaluator.Evaluate(computation, 
arg_literals_ptrs)); // Copy the result into the return buffer diff --git a/tensorflow/compiler/plugin/executor/executor.cc b/tensorflow/compiler/plugin/executor/executor.cc index e72c2711f794792fd4d7834b07eee5d983dff0a0..908b996bc95ac8d36f6c5577857b1a3a3826c3d4 100644 --- a/tensorflow/compiler/plugin/executor/executor.cc +++ b/tensorflow/compiler/plugin/executor/executor.cc @@ -14,14 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/plugin/executor/executor.h" -#include "tensorflow/compiler/plugin/executor/platform_id.h" - -#include "tensorflow/compiler/xla/status_macros.h" #include #include -namespace se = ::perftools::gputools; +#include "tensorflow/compiler/plugin/executor/platform_id.h" +#include "tensorflow/compiler/xla/status_macros.h" namespace perftools { namespace gputools { @@ -37,10 +35,7 @@ ExecutorExecutor::ExecutorExecutor(const PluginConfig &plugin_config) ExecutorExecutor::~ExecutorExecutor() {} -void *ExecutorExecutor::Allocate(uint64 size) { - void *buf = new char[size]; - return buf; -} +void *ExecutorExecutor::Allocate(uint64 size) { return new char[size]; } void *ExecutorExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, uint64 offset_bytes, @@ -126,8 +121,7 @@ DeviceDescription *ExecutorExecutor::PopulateDeviceDescription() const { builder.set_device_memory_size(static_cast(4) * 1024 * 1024 * 1024); builder.set_clock_rate_ghz(static_cast(CLOCKS_PER_SEC) / 1e9); - auto built = builder.Build(); - return built.release(); + return builder.Build().release(); } } // namespace executorplugin diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc index b59d20a7791f1ed2df2f35c6186e34e64fe4b248..51c5deeea5d5fd03d0fb99d4f33413c7bf4abe0f 100644 --- a/tensorflow/compiler/plugin/executor/transfer_manager.cc +++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc @@ -70,13 
+70,13 @@ Status ExecutorTransferManager::TransferLiteralFromDevice( } *literal->mutable_shape() = device_shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal); + literal->Reserve(ShapeUtil::ElementsIn(device_shape)); TF_RETURN_IF_ERROR(TransferBufferFromDevice( executor, source, ShapeUtil::ByteSizeOf(device_shape), - LiteralUtil::MutableInternalData(literal))); + literal->MutableInternalData())); if (!ShapeUtil::Equal(literal_shape, device_shape)) { literal->Swap( - LiteralUtil::Relayout(*literal, literal_shape.layout()).get()); + literal->Relayout(literal_shape.layout()).get()); } TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); return Status::OK(); @@ -134,7 +134,7 @@ Status ExecutorTransferManager::TransferLiteralToDevice( } return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), - LiteralUtil::InternalData(literal), + literal.InternalData(), destination); } @@ -147,6 +147,11 @@ Status ExecutorTransferManager::TransferLiteralToInfeed( return Status::OK(); } +Status ExecutorTransferManager::TransferBufferToInfeed( + se::StreamExecutor* executor, int64 size, const void* source) { + return Unimplemented("Transfer to Infeed"); +} + Status ExecutorTransferManager::TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) { diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.h b/tensorflow/compiler/plugin/executor/transfer_manager.h index 22142cd778a0aeccb6c393bdc1593e6213de858a..7a42e5a2d7542eaad7f8f90f011c65a9c526cc11 100644 --- a/tensorflow/compiler/plugin/executor/transfer_manager.h +++ b/tensorflow/compiler/plugin/executor/transfer_manager.h @@ -55,6 +55,9 @@ class ExecutorTransferManager : public TransferManager { Status TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, const void* source) override; + Status 
TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4bbb2767ac033dd9995cad37886d476fc87618da..432b24756d2247ce46b700aaf74ac28088349357 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -40,6 +40,7 @@ py_library( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", + "//tensorflow/python:random_seed", "//tensorflow/python:variables", ], ) @@ -174,6 +175,11 @@ tf_xla_py_test( name = "slice_ops_test", size = "small", srcs = ["slice_ops_test.py"], + # TODO(b/62962492): Test fails with assertion error. + tags = [ + "manual", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -323,7 +329,7 @@ tf_xla_py_test( tf_xla_py_test( name = "reverse_ops_test", - size = "small", + size = "medium", srcs = ["reverse_ops_test.py"], deps = [ ":xla_test", @@ -455,6 +461,11 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", ], + # TODO(b/62961789): Test fails with SIGABRT + tags = [ + "manual", + "notap", + ], ) cc_library( diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 6b328fb618bf8b9174dce756487494994b8aea04..79182768499074b4b409924374891ec58a9d11e0 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -218,7 +218,40 @@ class FtrlOptimizerTest(XLATestCase): self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval()) self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval()) - # When variables are intialized with Zero, FTRL-Proximal has two properties: + def testFtrlWithL1_L2_L2Shrinkage(self): + """Test the new FTRL op with support for l2 shrinkage. 
+ + The addition of this parameter which places a constant pressure on weights + towards the origin causes the gradient descent trajectory to differ. The + weights will tend to have smaller magnitudes with this parameter set. + """ + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + opt = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0, + l2_shrinkage_regularization_strength=0.1) + ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps FTRL + for _ in range(10): + ftrl_update.run() + + # Validate updated params + self.assertAllClose(np.array([-0.21931979, -0.40642974]), var0.eval()) + self.assertAllClose(np.array([-0.0282721, -0.07188385]), var1.eval()) + + # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical # with GradientDescent. # 2. 
Without L1&L2 but with adaptive learning rate, FTRL-Proximal is idential diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index 9c3b86c84b2b92089da0dfc0070a4a7b8a03c81a..c013f4b50a4cf95be8028248c52b10b1c3be2bd3 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -228,34 +228,40 @@ class SpaceToBatchNDTest(XLATestCase): outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]], [[4, 41], [6, 61]]]) - def testDirect(self): + def testDirect0(self): # Test with zero-size remaining dimension. self._testDirect( input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]]) + def testDirect1(self): # Test with zero-size blocked dimension. self._testDirect( input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]]) + def testDirect2(self): # Test with padding up from zero size. self._testDirect( input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]]) + def testDirect3(self): self._testDirect( input_shape=[3, 3, 4, 5, 2], block_shape=[3, 4, 2], paddings=[[1, 2], [0, 0], [3, 0]]) + def testDirect4(self): self._testDirect( input_shape=[3, 3, 4, 5, 2], block_shape=[3, 4, 2, 2], paddings=[[1, 2], [0, 0], [3, 0], [0, 0]]) + def testDirect5(self): self._testDirect( input_shape=[3, 2, 2, 3, 4, 5, 2, 5], block_shape=[1, 1, 3, 4, 2, 2], paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]]) + def testDirect6(self): self._testDirect( input_shape=[3, 2, 2, 3, 4, 5, 2, 5], block_shape=[1, 1, 3, 4, 2, 2, 1], diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 27a29773053e08c755afce23c3257d96ce27a929..b3067be51dd3ea0a43930ad3265a6330b3eec8e9 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -335,7 +335,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0_bad = 
gen_data_flow_ops._tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) with self.assertRaisesOpError( - "TensorArray dtype is float but Op requested dtype double."): + "TensorArray dtype is float but op has dtype double."): r0_bad.eval() # Test reading from a different index than the one we wrote to @@ -573,13 +573,12 @@ class TensorArrayTest(xla_test.XLATestCase): [2000.0, -2000.0]], grad_vals[0]) - # TODO(phawkins): implement TensorArrayClose - # def testCloseTensorArray(self): - # with self.test_session() as session, self.test_scope(): - # ta = tensor_array_ops.TensorArray( - # dtype=dtypes.float32, tensor_array_name="foo", size=3) - # c1 = ta.close() - # session.run(c1) + def testCloseTensorArray(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c1 = ta.close() + session.run(c1) def testSizeTensorArray(self): with self.test_session(), self.test_scope(): @@ -588,17 +587,16 @@ class TensorArrayTest(xla_test.XLATestCase): s = ta.size() self.assertAllEqual(3, s.eval()) - # TODO(phawkins): implement TensorArrayClose - # def testWriteCloseTensorArray(self): - # with self.test_session(), self.test_scope(): - # ta = tensor_array_ops.TensorArray( - # dtype=dtypes.float32, - # tensor_array_name="foo", - # size=3, - # infer_shape=False) - # w0 = ta.write(0, [[4.0, 5.0]]) - # w1 = w0.write(1, [3.0]) - # w1.close().run() # Expected to run without problems + def testWriteCloseTensorArray(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + w0 = ta.write(0, [[4.0, 5.0]]) + w1 = w0.write(1, [3.0]) + w1.close().run() # Expected to run without problems # TODO(phawkins): implement while loops. 
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 93c484ca7a0d04654371724aac905eb055c82b05..997ecd7ebb5b6398167a1ba309f18ef9ec837865 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -42,6 +42,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -152,7 +153,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", ], ) @@ -165,13 +165,10 @@ cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", ], ) @@ -203,6 +200,59 @@ cc_library( ], ) +cc_library( + name = "functionalize_control_flow", + srcs = ["functionalize_control_flow.cc"], + hdrs = ["functionalize_control_flow.h"], + deps = [ + "//tensorflow/compiler/jit:graph_to_functiondef", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + ], +) + +cc_test( + name = "functionalize_control_flow_test", + srcs = ["functionalize_control_flow_test.cc"], + deps = [ + ":functionalize_control_flow", + ":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/compiler/tf2xla/cc:functional_ops", + "//tensorflow/compiler/xla:status_macros", 
+ "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:ops", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..599265ba449c88baef1671b1c81d96d1715ce5f2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -0,0 +1,44 @@ +package( + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") + +tf_gen_op_wrapper_cc( + name = "functional_ops_gen", + include_internal_ops = 1, + out_ops_file = "ops/functional_ops", + deps = ["//tensorflow/compiler/tf2xla/ops:functional_ops"], +) + +cc_library( + name = "functional_ops", + srcs = ["ops/functional_ops.cc"], + hdrs = ["ops/functional_ops.h"], + deps = [ + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git 
a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc new file mode 100644 index 0000000000000000000000000000000000000000..faa88ecfe2efdccbb75e3ae2e81452b33393a307 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -0,0 +1,569 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/graph_to_functiondef.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/control_flow.h" + +namespace tensorflow { + +namespace { + +const char* const kArgOp = "_Arg"; +const char* const kRetValOp = "_Retval"; + +// Information about a loop argument. +struct Arg { + // Every loop argument has an Enter node. + Node* enter; + + // Is the loop argument a loop-invariant value? Taken from the `is_constant` + // attribute on the Enter node. + bool is_loop_invariant; + + // If 'is_loop_invariant' is true, the following are all nullptr. 
Non-constant + // arguments must have all of the following nodes: + Node* merge = nullptr; + Node* switch_node = nullptr; + Node* next_iteration = nullptr; + Node* exit = nullptr; +}; + +// Information about a loop frame. +struct Frame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + Frame* parent = nullptr; + int num_children = 0; + + // Arguments to this loop. + std::vector<Arg> args; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + Node* loop_cond = nullptr; + + // Set of nodes that belong to the loop frame. + std::unordered_set<Node*> nodes; +}; + +// Copies a subgraph from `graph` to `output` by performing a reverse DFS +// starting at nodes in vector `stack`. +// `node_map` is a vector indexed by source node ID to dest nodes. +// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` +// before the traversal clients can cut the graph. Returns an error if the +// traversal leaves 'frame'; the client must add enough nodes to `node_map` to +// cut the graph and prevent the traversal from escaping. +// +// `squash_src_outputs` contains a bool for each source node ID. If true, then +// the source output on that node will be replaced by zero when copied. This is +// used when replacing a Switch node with an _Arg node. The output we are +// taking from the Switch node was not necessarily the first output, but _Arg +// nodes only have one output. By adding the Switch node to `squash_src_outputs` +// we rewrite the src_output of the corresponding edge to be 0.
+Status CopySubgraph(const Graph& graph, const Frame& frame, + std::vector<Node*> stack, + const std::vector<bool>& squash_src_outputs, + std::vector<Node*>* node_map, Graph* output) { + std::vector<bool> visited(graph.num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + VLOG(3) << "Copying node " << n->name(); + + if (visited[n->id()]) continue; + visited[n->id()] = true; + + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (frame.nodes.find(src) == frame.nodes.end()) { + // We traversed out of the loop frame, without encountering a cut node. + return errors::Internal("Graph traversal of loop frame ", frame.name, + " escaped frame at ", src->name(), + " without encountering an argument node."); + } + // Nodes already present in `node_map` act as cut points: they are wired + // up below but never pushed, so the traversal does not continue past them. + if ((*node_map)[src->id()] == nullptr) { + (*node_map)[src->id()] = output->CopyNode(src); + stack.push_back(src); + } + Node* src_copy = (*node_map)[e->src()->id()]; + int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output(); + Node* dst_copy = (*node_map)[e->dst()->id()]; + output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); + } + } + return Status::OK(); +} + +// Adds a node with op kArgOp, named "_Arg" followed by `index`, to `graph`, +// with attributes T=`type` and index=`index`, and stores it in `*arg_node`. +// Returns the error from finalizing the NodeDef or from Graph::AddNode, if any. +Status BuildArgNode(Graph* graph, DataType type, int index, Node** arg_node) { + NodeDef arg_def; + NodeDefBuilder builder(strings::StrCat("_Arg", index), kArgOp); + builder.Attr("T", type); + builder.Attr("index", index); + TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); + Status status; + *arg_node = graph->AddNode(arg_def, &status); + return status; +} + +// Adds a node with op kRetValOp, named "_Retval" followed by `index`, to +// `graph`, with attributes T=`type` and index=`index`, and stores it in +// `*retval_node`. The NodeDef is assembled by hand, so no input edge is set up +// here. Returns the status reported by Graph::AddNode. +Status BuildRetvalNode(Graph* graph, DataType type, int index, + Node** retval_node) { + NodeDef ret_def; + ret_def.set_op(kRetValOp); + ret_def.set_name(strings::StrCat("_Retval", index)); + AddNodeAttr("T", type, &ret_def); + AddNodeAttr("index", index, &ret_def); + Status status; + *retval_node = graph->AddNode(ret_def, &status); + return status; +} + +// Builds a graph for the loop condition.
+Status BuildLoopCondition(const Graph& graph, Frame* frame, + std::unique_ptr<Graph>* cond_output) { + VLOG(2) << "Building loop condition for " << frame->name; + *cond_output = xla::MakeUnique<Graph>(graph.op_registry()); + Graph* output = cond_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector<Node*> node_map(graph.num_node_ids(), nullptr); + std::vector<bool> squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + Node* arg_node; + TF_RETURN_IF_ERROR( + BuildArgNode(output, arg.enter->input_type(0), i, &arg_node)); + if (arg.is_loop_invariant) { + node_map[arg.enter->id()] = arg_node; + } else { + node_map[arg.merge->id()] = arg_node; + } + } + + // Build a Retval node for the loop condition. The LoopCond nodes are always + // boolean because of the type constraints on the LoopCond op. + TF_RETURN_IF_ERROR( + BuildRetvalNode(output, DT_BOOL, 0, &node_map[frame->loop_cond->id()])); + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, {frame->loop_cond}, + squash_src_outputs, &node_map, output)); + + return Status::OK(); +} + +// Builds a graph for the loop body. +Status BuildLoopBody(const Graph& graph, Frame* frame, + DataTypeVector* arg_types, + std::unique_ptr<Graph>* body_output) { + VLOG(2) << "Building loop body for " << frame->name; + *body_output = xla::MakeUnique<Graph>(graph.op_registry()); + Graph* output = body_output->get(); + + // Map from nodes in the original graph to the body graph. + std::vector<Node*> node_map(graph.num_node_ids(), nullptr); + std::vector<bool> squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node.
+ std::vector next_iterations; + next_iterations.reserve(frame->args.size()); + arg_types->reserve(frame->args.size()); + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + DataType dtype = arg.enter->input_type(0); + arg_types->push_back(dtype); + Node* arg_node; + TF_RETURN_IF_ERROR(BuildArgNode(output, dtype, i, &arg_node)); + + if (dtype == DT_RESOURCE) { + // The convention of the XLA bridge is that resource variable arguments + // are only inputs to the loop body and have no corresponding output. + // TODO(b/37741920): change the convention so that DT_RESOURCE variables + // are both inputs and outputs, and then remove this case. + TF_RET_CHECK(arg.is_loop_invariant); + node_map[arg.enter->id()] = arg_node; + } else { + Node* retval_node; + TF_RETURN_IF_ERROR(BuildRetvalNode(output, dtype, i, &retval_node)); + + if (arg.is_loop_invariant) { + // Argument is loop-invariant. Forward it from the Arg to the Retval. + node_map[arg.enter->id()] = arg_node; + output->AddEdge(arg_node, 0, retval_node, 0); + } else { + // Argument is loop-varying. + node_map[arg.switch_node->id()] = arg_node; + // The Switch node has two outputs, but _Arg only has one. This tells + // the CopySubgraph function to rewrite the output number of edges from + // the _Arg node to be 0 rather than copying the output number from the + // Switch node. + squash_src_outputs[arg.switch_node->id()] = true; + node_map[arg.next_iteration->id()] = retval_node; + next_iterations.push_back(arg.next_iteration); + } + } + } + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. 
+ TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, std::move(next_iterations), + squash_src_outputs, &node_map, output)); + + return Status::OK(); +} + +Status FunctionalizeLoop(Graph* graph, Frame* frame, + FunctionLibraryDefinition* library) { + VLOG(2) << "Frame " << frame->name << " before: " + << dump_graph::DumpGraphToFile("functionalize_before", *graph); + + // Split loop-varying Enter nodes with multiple successors. If the same + // Tensor is fed as input to multiple loop arguments, we may end up with a + // shared Enter node. We clone Enter nodes with multiple successors to + // maintain the invariant of a unique Enter node per argument of the final + // loop. + std::vector args; + for (const Arg& arg : frame->args) { + if (arg.is_loop_invariant) { + args.push_back(arg); + } else { + std::vector edges(arg.enter->out_edges().begin(), + arg.enter->out_edges().end()); + for (int i = 0; i < edges.size(); ++i) { + if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { + continue; + } + TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); + Arg new_arg; + new_arg.is_loop_invariant = false; + if (i == 0) { + new_arg.enter = arg.enter; + } else { + new_arg.enter = graph->CopyNode(arg.enter); + frame->nodes.insert(new_arg.enter); + for (Edge const* e : arg.enter->in_edges()) { + graph->AddEdge(e->src(), e->src_output(), new_arg.enter, + e->IsControlEdge() ? Graph::kControlSlot : 0); + } + Node* dst = edges[i]->dst(); + int dst_input = edges[i]->dst_input(); + graph->RemoveEdge(edges[i]); + graph->AddEdge(new_arg.enter, 0, dst, dst_input); + } + args.push_back(new_arg); + } + } + } + frame->args = std::move(args); + + // Order the arguments so that: + // a) resource variables are last, and + // b) sort lexicographically by name (for deterministic output). 
+ std::sort(frame->args.begin(), frame->args.end(), + [](const Arg& a, const Arg& b) { + bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE); + bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE); + return std::tie(a_is_resource, a.enter->name()) < + std::tie(b_is_resource, b.enter->name()); + }); + + if (frame->loop_cond == nullptr) { + return errors::InvalidArgument("Loop ", frame->name, + " has no LoopCond node"); + } + + // Find the set of Switch nodes that are successors of the LoopCond. + std::unordered_set switches; + for (const Edge* edge : frame->loop_cond->out_edges()) { + if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && + edge->dst_input() == 1) { + switches.insert(edge->dst()); + } + } + + // For each non-constant argument, looks for the following pattern of nodes: + // Enter ----> Merge --------> Switch --> Exit + // ^ ^ + // | | + // NextIteration LoopCond + // ^ ^ + // | | + // ... ... + for (Arg& arg : frame->args) { + if (!arg.is_loop_invariant) { + // Follow the edge from the Enter to Merge. + if (arg.enter->out_edges().size() != 1) { + return errors::Internal("Enter node for loop-varying argument ", + arg.enter->name(), + " does not have exactly one successor"); + } + const Edge* enter_merge = *arg.enter->out_edges().begin(); + arg.merge = enter_merge->dst(); + if (!IsMerge(arg.merge)) { + return errors::InvalidArgument( + "Successor of Enter node for loop-varying argument ", + arg.merge->name(), + " is not a Merge node; got: ", arg.merge->type_string()); + } + + // Find the NextIteration from the merge. There should be two inputs to + // the Merge and the NextIteration should be the other input. 
+ if (arg.merge->input_types().size() != 2) { + return errors::InvalidArgument( + "Unexpected number of inputs to Merge node for loop-varying " + "argument ", + arg.merge->name(), "; expected 2, got ", + arg.merge->input_types().size()); + } + TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), + &arg.next_iteration)); + if (!IsNextIteration(arg.next_iteration)) { + return errors::InvalidArgument( + "Expected NextIteration node as input to Merge node; got node ", + arg.next_iteration->name(), " with kind ", + arg.next_iteration->type_string()); + } + + // Find the Switch successor of the Merge. There should be exactly one + // Switch node that is a successor of both the Merge and the LoopCond. + for (const Edge* edge : arg.merge->out_edges()) { + if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && + switches.find(edge->dst()) != switches.end()) { + if (arg.switch_node != nullptr) { + return errors::InvalidArgument("Duplicate Switch successors to ", + arg.merge->name()); + } + arg.switch_node = edge->dst(); + } + } + if (arg.switch_node == nullptr) { + return errors::InvalidArgument("Missing Switch successor to ", + arg.merge->name()); + } + + // Find the Exit successor of the Switch. + for (const Edge* edge : arg.switch_node->out_edges()) { + if (edge->src_output() == 0 && IsExit(edge->dst())) { + if (arg.exit != nullptr) { + return errors::InvalidArgument("Duplicate Exit successors to ", + arg.switch_node->name()); + } + arg.exit = edge->dst(); + } + } + if (arg.exit == nullptr) { + return errors::InvalidArgument("Missing Exit successor to ", + arg.switch_node->name()); + } + } + } + + // Builds the condition and body functions. 
+ std::unique_ptr cond_graph; + TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + DataTypeVector arg_types; + std::unique_ptr body_graph; + TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + + VLOG(2) << "Frame " << frame->name << " condition: " + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph) + << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); + + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + NameAttrList cond_name; + cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); + NameAttrList body_name; + body_name.set_name(strings::StrCat("_functionalize_body_", id)); + FunctionDef cond_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); + FunctionDef body_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); + + TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + + // Builds a While operator. + NodeDef while_def; + NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + builder.Attr("T", arg_types); + builder.Attr("cond", cond_name); + builder.Attr("body", body_name); + std::vector inputs; + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + inputs.push_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), arg_types[i])); + } + } + builder.Input(inputs); + TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); + + Status status; + Node* while_node = graph->AddNode(while_def, &status); + if (!status.ok()) { + return status; + } + + // Copies edges to the Enter nodes and from the Exit nodes onto the While. 
+ for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph->AddControlEdge(in_edge->src(), while_node); + } else { + graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); + } + + if (!arg.is_loop_invariant) { + std::vector edges(arg.exit->out_edges().begin(), + arg.exit->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + int src_output = + dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; + graph->AddEdge(while_node, src_output, dst, dst_input); + } + } + } + + // Remove the old nodes from the graph, and add the while node to the parent + // frame. + for (Node* node : frame->nodes) { + graph->RemoveNode(node); + } + frame->parent->nodes.insert(while_node); + + VLOG(2) << "Frame " << frame->name << " after: " + << dump_graph::DumpGraphToFile("functionalize_after", *graph); + + return Status::OK(); +} + +} // namespace + +// Transformation that converts Tensorflow's graph control flow constructs into +// functional equivalents. +Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(2) << "FunctionalizeControlFlow: " + << dump_graph::DumpGraphToFile("functionalize_initial", *graph); + // Note: BuildControlFlowInfo() requires that the graph's source node is + // connected to all source nodes in the graph. Many graphs violate this + // invariant. + std::vector cf_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info)); + + // Builds Frames, indexed by name. + std::unordered_map frames; + for (Node* node : graph->op_nodes()) { + const ControlFlowInfo& cf = cf_info[node->id()]; + + VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name + << " frame: " << (cf.frame ? cf.frame->name() : "---") + << " parent_frame: " + << (cf.parent_frame ? 
cf.parent_frame->name() : "---"); + TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); + + Frame& frame = frames[cf.frame_name]; + Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; + if (frame.parent == nullptr) { + frame.parent = parent; + frame.name = cf.frame_name; + ++parent->num_children; + } else if (frame.parent != parent) { + return errors::InvalidArgument("Mismatched parent frames for ", + cf.frame->id(), ": ", parent->name, " vs ", + frame.parent->name); + } + + if (IsEnter(node)) { + Arg arg; + arg.enter = node; + TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", + &arg.is_loop_invariant)); + frame.args.push_back(arg); + } else if (IsLoopCond(node)) { + if (frame.loop_cond) { + return errors::InvalidArgument( + "Loop ", cf.frame_name, + " has more than one LoopCond node: ", node->name(), " and ", + frame.loop_cond->name()); + } + frame.loop_cond = node; + } + frame.nodes.insert(node); + } + + // Adds frames with no children (i.e., the innermost frames) to a worklist. + std::deque worklist; + for (auto& frame : frames) { + if (frame.second.num_children == 0) { + worklist.push_back(&frame.second); + } + } + + // Eliminate loops from innermost to outermost. + while (!worklist.empty()) { + Frame* frame = worklist.front(); + worklist.pop_front(); + if (frame->parent == frame) { + // Skip the root frame. + continue; + } + + TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + + // If the parent has no remaining children, add it to the worklist. 
+ --frame->parent->num_children; + if (frame->parent->num_children == 0) { + worklist.push_back(frame->parent); + } + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h new file mode 100644 index 0000000000000000000000000000000000000000..1535dc80b0ccdba38c57b534ed7473fc8632e33f --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Transformation that converts tf.while_loop() loops into functional While +// operators, suitable for XLA compilation. +// TODO(b/36470387): add support for conditionals. 
+Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2fb1cc04543561f2e8a296352b0d4922a0eba37c --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -0,0 +1,650 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { +namespace { + +// Returns the names of the "cond" and "body" functions for the While node +// in a graph. 
+Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, + NameAttrList* body) { + for (const NodeDef& node : graph.node()) { + if (node.op() == "XlaWhile") { + const NameAttrList* result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); + *cond = *result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result)); + *body = *result; + return Status::OK(); + } + } + return errors::NotFound("No XlaWhile node found in graph"); +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) +TEST(FunctionalizeControlFlow, OneLoopVar) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = + ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); + // Add an unused Enter node. These should be ignored. 
+ auto enter2 = + ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto next_iteration = + ops::NextIteration(scope.WithOpName("while/NextIteration"), add); + + auto sink = ops::Identity(scope.WithOpName("sink"), exit); + + // Remove the dummy node and add the loop backedge. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = array_ops.placeholder(dtypes.int32) +// cond = lambda (i, j): i + 3 < 10 +// body = lambda (i, j): (i + 1, j * 2) +// z = control_flow_ops.while_loop(cond, body, [x, y]) +// NOTE(review): the "while/mul" node below is built with ops::Add, so the +// graph as constructed computes j + 2 rather than j * 2 -- confirm intent. +TEST(FunctionalizeControlFlow, TwoLoopVars) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto enter_x = + ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop"); + auto enter_y = + ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop"); + auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"), + std::initializer_list{enter_x, dummy}); + auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"), + std::initializer_list{enter_y, dummy}); + + // Loop condition + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(merge_x.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
+ auto ten = ops::Const(scope.WithOpName("while/cond/ten") + .WithControlDependencies(merge_x.output), + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + + auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"), + merge_x.output, loop_cond); + auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"), + merge_y.output, loop_cond); + + auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"), + switch_x.output_false); + auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"), + switch_y.output_false); + + auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), + switch_x.output_true); + auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), + switch_y.output_true); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto next_iteration_x = + ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add); + auto next_iteration_y = + ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul); + + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y); + + // Remove the dummy node and add the loop backedges. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(), + 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(arg0.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); + auto ten = ops::Const( + scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + + auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0); + auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +// Example with nesting, loop-invariant arguments, and resource variables. 
+// +// accum = resource_variable_ops.ResourceVariable(1) +// x = array_ops.placeholder(2, dtype=dtypes.int32) +// y = 3 + x +// +// def inner_body(j, k): +// add = state_ops.assign_add(accum, k * j + x) +// with ops.control_dependencies([add]): +// return [j + 1, k] +// +// def body(i): +// m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body, +// [1, y], name="inner") +// with ops.control_dependencies(m): +// return [i + 1] +// +// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer") +TEST(FunctionalizeControlFlow, Complex) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + // Outer loop + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + auto enter_i = + ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer"); + auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"), + std::initializer_list{enter_i, dummy}); + auto ten = ops::Const(scope.WithOpName("outer/Less/y") + .WithControlDependencies(merge_i.output), + 10); + auto less_i = + ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten); + auto outer_loop_cond = + ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i); + auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"), + merge_i.output, outer_loop_cond); + auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"), + switch_i.output_false); + auto identity_i = + ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true); + + auto enter_x_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + auto 
enter_k_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + auto enter_var_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + + // Inner loop + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"), + one_j, "inner"); + auto enter_k = + ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k") + .WithControlDependencies(identity_i), + enter_k_outer, "inner"); + auto enter_x = ops::internal::Enter( + scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner", + ops::internal::Enter::Attrs().IsConstant(true)); + auto enter_var = ops::internal::Enter( + scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner", + ops::internal::Enter::Attrs().IsConstant(true)); + + auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"), + std::initializer_list{enter_j, dummy}); + auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"), + std::initializer_list{enter_k, dummy}); + + auto five = ops::Const(scope.WithOpName("outer/inner/Five") + .WithControlDependencies(merge_j.output), + 5); + auto less_j = + ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five); + auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j); + + auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"), + merge_j.output, loop_cond); + auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"), + merge_k.output, loop_cond); + auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"), + switch_j.output_false); + auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"), + switch_k.output_false); + auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"), + switch_j.output_true); + auto 
identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"), + switch_k.output_true); + + // Variable update + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = + ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); + + auto one = + ops::Const(scope.WithOpName("outer/inner/One") + .WithControlDependencies( + gtl::ArraySlice{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto next_iteration_j = ops::NextIteration( + scope.WithOpName("outer/inner/NextIteration_j"), add_j); + auto next_iteration_k = ops::NextIteration( + scope.WithOpName("outer/inner/NextIteration_k"), identity_k); + + // Body and backedge for outer loop. + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(gtl::ArraySlice{ + exit_j.output.op(), exit_k.output.op()}), + identity_i, one_outer); + auto next_iteration_i = + ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i); + + auto sink = ops::Identity(scope.WithOpName("sink"), exit_i); + + // Remove the dummy node and add the loop backedge. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(), + 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList outer_cond_fn, outer_body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); + + // Outer graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + + auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Outer condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto ten = ops::Const( + scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Outer body graph. + NameAttrList inner_cond_fn, inner_body_fn; + { + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); + + // Find the inner condition and body names. 
+ TF_EXPECT_OK( + FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto while_op = + ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); + + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(gtl::ArraySlice{ + while_op[0].op(), while_op[1].op()}), + identity_i, one_outer); + + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Inner condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto five = ops::Const( + scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); + auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Inner body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_j = + ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); + auto identity_k = + ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + + auto one = + ops::Const(scope.WithOpName("outer/inner/One") + .WithControlDependencies( + gtl::ArraySlice{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto 
retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); + auto retval1 = + ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index a434c7468095a05ee6da31826d44379a735b51f7..96b4fdfec6dff8bd8d4aca19181db0b64e3e2a1c 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -68,6 +68,7 @@ tf_kernel_library( "reduction_ops.h", ], deps = [ + ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal_util", @@ -91,6 +92,21 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "while_op", + srcs = ["while_op.cc"], + hdrs = ["while_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + ], +) + # Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. 
# diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 620fc8443785388781caf5121da53c4d908d4cb4..6ad72c6219e01b323d79611e2ad67a6cf4d2f390 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -51,13 +51,26 @@ class ArgOp : public XlaOpKernel { XlaContext& xc = XlaContext::Get(ctx); const XlaContext::Argument& arg = xc.args()[index_]; - if (arg.is_variable) { + if (arg.is_resource) { + XlaResource::Kind kind; + switch (arg.kind) { + case XlaCompiler::Argument::kVariable: + kind = XlaResource::kVariable; + break; + case XlaCompiler::Argument::kTensorArray: + kind = XlaResource::kTensorArray; + break; + default: + CHECK(false); + } + // TODO(phawkins): this code assumes that variables do not alias. - XlaVariable* var; - OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type, - arg.value.handle, &var)); - var->tensor_array_size = arg.tensor_array_size; - ctx->SetVariableOutput(0, var); + XlaResource* resource; + OP_REQUIRES_OK(ctx, + xc.CreateResource(kind, index_, arg.name, arg.value.type, + arg.value.handle, &resource)); + resource->tensor_array_size = arg.tensor_array_size; + ctx->SetResourceOutput(0, resource); } else if (arg.value.is_constant) { ctx->SetConstantOutput(0, arg.value.constant_value); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 8642cbf2a924e3c82c80bff8f5122e62ce12082d..21d3e64872e19109852297838043975cea6d7921 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -127,8 +127,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, std::vector end_indices = reshaped_permuted_shape; std::vector strides(input_rank, 1); for (int i = 0; i < block_rank; ++i) { - int64 crop_start = xla::LiteralUtil::Get(crops, {i, 0}); - int64 crop_end = xla::LiteralUtil::Get(crops, {i, 
1}); + int64 crop_start = crops.Get({i, 0}); + int64 crop_end = crops.Get({i, 1}); OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0, errors::InvalidArgument("Crops must be non-negative")); start_indices[1 + i] = crop_start; diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index b0fee5e4bca502a7abb4613b58ecdd2ffca2206d..bc2cd31230dfe9ca35540341d225dcb768fa34f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -55,7 +55,7 @@ class BCastGradArgsOp : public XlaOpKernel { BCast::Vec vec; for (int64 i = 0; i < in_shape.num_elements(); ++i) { - vec.push_back(xla::LiteralUtil::Get(literal, {i})); + vec.push_back(literal.Get({i})); } shapes.push_back(vec); } diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 124e33d7935ce19ced72d1c84521ffda1090bc86..2331520230176fce7646d89140851fe37aee5fda 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -38,17 +38,6 @@ class CastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; - } else if (src_dtype_ == DT_BOOL) { - // XLA's ConvertElementType doesn't support casting to/from - // bools. So we need to handle those cases separately. - // Builds the equivalent of (input ? 
1 : 0) - xla::ComputationBuilder l(builder->client(), "PredCast"); - xla::ComputationDataHandle x = - l.Parameter(0, xla::ShapeUtil::MakeShape(src_type_, {}), "x"); - l.Select(x, XlaHelpers::One(&l, dst_dtype_), - XlaHelpers::Zero(&l, dst_dtype_)); - xla::Computation computation = l.Build().ConsumeValueOrDie(); - output = builder->Map({input}, computation); } else if (dst_dtype_ == DT_BOOL) { output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e2eacb3839d39e6fa41192e8aa0f31d878d96aea..73a4740e29af7fa57e71ef42a342f46b0e24231d 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -52,7 +52,7 @@ class ConcatBaseOp : public XlaOpKernel { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); // TODO(annarev): add a helper to support int64 input. - const int32 concat_dim = xla::LiteralUtil::Get(literal, {}); + const int32 concat_dim = literal.Get({}); std::vector values; std::vector shapes; @@ -163,7 +163,7 @@ class ConcatOffsetOp : public XlaOpKernel { xla::Literal concat_dim_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); - const int64 cdim = xla::LiteralUtil::Get(concat_dim_literal, {}); + const int64 cdim = concat_dim_literal.Get({}); VLOG(1) << "ConcatOffset " << cdim << "," << dims; int32 axis = cdim < 0 ? 
cdim + dims : cdim; @@ -185,12 +185,10 @@ class ConcatOffsetOp : public XlaOpKernel { for (int64 j = 0; j < dims; ++j) { if (j == axis) { out_vec(j) = offset; - offset += xla::LiteralUtil::Get(inp_literal, {j}); + offset += inp_literal.Get({j}); } else { - const int32 inp0_element = - xla::LiteralUtil::Get(inp0_literal, {j}); - const int32 inp_element = - xla::LiteralUtil::Get(inp_literal, {j}); + const int32 inp0_element = inp0_literal.Get({j}); + const int32 inp_element = inp_literal.Get({j}); OP_REQUIRES( ctx, (inp0_element == inp_element), errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index ad676e7a2bb3d3f28ecb98164323cbf1e32f61a9..9833323d851e00e7ca76d0b39cd2b216748a17fa 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 107c673f4a7d62f8b760b137aeda2864e156b7f7..dde7898015e73190c96fa6effddfd3fc892264ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -63,11 +63,14 @@ class DynamicStitchOp : public XlaOpKernel { std::vector indices(indices_input.size()); const TensorShape& data0_shape = data_shapes[0]; - const TensorShape indices0_shape = - XLAShapeToTensorShape(indices_input[0].shape()); + TensorShape indices0_shape; + OP_REQUIRES_OK( + ctx, XLAShapeToTensorShape(indices_input[0].shape(), &indices0_shape)); for (int input_num = 0; input_num < indices_input.size(); 
input_num++) { - const TensorShape indices_shape = - XLAShapeToTensorShape(indices_input[input_num].shape()); + TensorShape indices_shape; + OP_REQUIRES_OK(ctx, + XLAShapeToTensorShape(indices_input[input_num].shape(), + &indices_shape)); const TensorShape& data_shape = data_shapes[input_num]; OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), errors::InvalidArgument( @@ -103,8 +106,7 @@ class DynamicStitchOp : public XlaOpKernel { int max_index = -1; for (int input_num = 0; input_num < indices.size(); input_num++) { for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) { - max_index = std::max( - max_index, xla::LiteralUtil::Get(indices[input_num], {i})); + max_index = std::max(max_index, indices[input_num].Get({i})); } } int number_of_indices = max_index + 1; @@ -118,7 +120,7 @@ class DynamicStitchOp : public XlaOpKernel { int index_used_count = 0; for (int input_num = 0; input_num < indices.size(); input_num++) { for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) { - int index = xla::LiteralUtil::Get(indices[input_num], {i}); + int index = indices[input_num].Get({i}); src_input_vector[index] = input_num; src_slice_vector[index] = i; if (!src_index_used[index]) { diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 1e1d2a1b4b3fa281adc96b76ade5ce7b07b2b41c..9e090fe01cbfd4dab81b0de21e3a44e42c2ef18e 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -52,7 +52,7 @@ class FillOp : public XlaOpKernel { std::vector broadcast; broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { - broadcast.push_back(xla::LiteralUtil::Get(dims_literal, {i})); + broadcast.push_back(dims_literal.Get({i})); } // Look up the value input, reshaping to a scalar if it was a // 'legacy' scalar (secretly a vector). 
diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index 8dacb6627bde516c92cb07b747207adbe85ada5b..af1085d5b35077b7ebd144bfb2473485e3b3de6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/node_def.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 49eadaf9d1f0ff1dbfa2321f20f9f833a0d4eb9a..3c1cdef5f80e7cd9151590403846677e20f999d1 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -66,10 +66,10 @@ class GatherOp : public XlaOpKernel { std::vector args; args.push_back(tc.GetOrCreateRuntimeContextParameter()); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR0(indices_shape.num_elements()))); + *xla::Literal::CreateR0(indices_shape.num_elements()))); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR0(params_shape.dim_size(0)))); - args.push_back(b.ConstantLiteral(*xla::LiteralUtil::CreateR0( + *xla::Literal::CreateR0(params_shape.dim_size(0)))); + args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0( params_shape.num_elements() / params_shape.dim_size(0)))); args.push_back(ctx->Input(0)); args.push_back(ctx->Input(1)); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index df002dddd043c6795481436586a31c74b20d33d1..6be66cf66ec19cad33858f36a3239048efce9de3 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -69,7 +69,7 @@ class ArgMaxOp : public XlaOpKernel { 
// XLA op would have the same requirement. xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - const int32 dim = xla::LiteralUtil::Get(literal, {}); + const int32 dim = literal.Get({}); OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); OP_REQUIRES( ctx, dim < input_shape.dims(), @@ -97,14 +97,13 @@ class ArgMaxOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + *xla::Literal::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); - args.push_back( - b.ConstantLiteral(*xla::LiteralUtil::CreateR0(dim))); + *xla::Literal::CreateR1(output_shape.dim_sizes()))); + args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 22476f4a0c51930cabf146313347e5e3bd2eaebe..cc13ab020344f525246e259449cf0ef221d3c09c 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -60,8 +60,8 @@ class PadOp : public XlaOpKernel { xla::PaddingConfig config; for (int i = 0; i < fixed_dims; ++i) { auto* dim = config.add_dimensions(); - int before = xla::LiteralUtil::Get(pad_literal, {i, 0}); - int after = xla::LiteralUtil::Get(pad_literal, {i, 1}); + int before = pad_literal.Get({i, 0}); + int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, errors::InvalidArgument("Paddings must be non-negative: ", before, " ", after)); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 
518a9372c4fa3f195ff7c77e8ef0de1ba0a8807b..dae2eb9d2a92ef8d4eabb8d6f9a79758c42d446d 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -63,7 +63,7 @@ class MinOp : public XlaReductionOp { xla::ComputationBuilder* builder) override { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::LiteralUtil::MaxValue(type)); + return builder->ConstantLiteral(xla::Literal::MaxValue(type)); } void BuildReducer(xla::ComputationBuilder* builder, @@ -83,7 +83,7 @@ class MaxOp : public XlaReductionOp { xla::ComputationBuilder* builder) override { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::LiteralUtil::MinValue(type)); + return builder->ConstantLiteral(xla::Literal::MinValue(type)); } void BuildReducer(xla::ComputationBuilder* builder, diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 8798c80ad5354c76a9b4061ad8913b76ae0629b0..4b5d09eb9fd4110cdc4221099ff55767e9132540 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -66,13 +66,13 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { 1, {axes_tensor_shape.num_elements()}, &axes_literal)); VLOG(1) << "data shape: " << data_shape.DebugString(); - VLOG(1) << "axes : " << xla::LiteralUtil::ToString(axes_literal); + VLOG(1) << "axes : " << axes_literal.ToString(); gtl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { - int32 index = xla::LiteralUtil::Get(axes_literal, {i}); + int32 index = axes_literal.Get({i}); OP_REQUIRES(ctx, !(index < -data_shape.dims() || index >= data_shape.dims()), errors::InvalidArgument("Invalid 
reduction dimension (", index, diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index df542350b443b765a1ab35be9632cf61a38be49c..5952e752724d1e6953dd4dbb6a8099b847c64d08 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -50,7 +50,7 @@ class ReshapeOp : public XlaOpKernel { int64 product = 1; int unknown_index = -1; for (int d = 0; d < num_dims; ++d) { - const int32 size = xla::LiteralUtil::Get(literal, {d}); + const int32 size = literal.Get({d}); if (size == -1) { OP_REQUIRES( ctx, unknown_index == -1, diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 5b6fa64fa825894b5d7bf938c5892d30f4fc11b0..c2b0e1bb4c1a141d0ab3f5b3ff5397d9da620bd8 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -32,7 +32,7 @@ template Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { xla::Literal literal; TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); return Status::OK(); } @@ -41,10 +41,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); switch (literal.shape().element_type()) { case xla::S32: - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); break; case xla::S64: - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); break; default: return errors::InvalidArgument("Invalid argument type for argument", @@ -58,9 +58,9 @@ template Status CreateRangeTensor(const xla::Literal& start_literal, const xla::Literal& limit_literal, const xla::Literal& delta_literal, Tensor* output) { - T start = xla::LiteralUtil::Get(start_literal, {}); - T limit = xla::LiteralUtil::Get(limit_literal, {}); - T delta = 
xla::LiteralUtil::Get(delta_literal, {}); + T start = start_literal.Get({}); + T limit = limit_literal.Get({}); + T delta = delta_literal.Get({}); if (delta == 0) { return errors::InvalidArgument("Requires delta != 0: ", delta); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index f15b354cb26d390352d866a8e827970f7c8b0c7f..83a87f19a718ce86a105e3c33ab9eaf0faff3a76 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -56,8 +56,8 @@ void SpaceToBatch(XlaOpKernelContext* ctx, padding_config.add_dimensions(); // Don't pad the batch dimension. for (int i = 0; i < block_rank; ++i) { auto* dim = padding_config.add_dimensions(); - int64 pad_start = xla::LiteralUtil::Get(paddings, {i, 0}); - int64 pad_end = xla::LiteralUtil::Get(paddings, {i, 1}); + int64 pad_start = paddings.Get({i, 0}); + int64 pad_end = paddings.Get({i, 1}); OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0, errors::InvalidArgument("Paddings must be non-negative")); dim->set_edge_padding_low(pad_start); diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 42bde90042218b3a36f50e32d4f458d31c82d5da..44ee81461e5b31f15594c0dfb86f7219f9875768 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -39,7 +39,7 @@ class SplitOp : public XlaOpKernel { int32 split_dim; if (index_shape.dims() == 0) { - split_dim = xla::LiteralUtil::Get(literal_index, {}); + split_dim = literal_index.Get({}); } else { OP_REQUIRES( ctx, index_shape.dims() == 1, @@ -49,7 +49,7 @@ class SplitOp : public XlaOpKernel { ctx, index_shape.dim_size(0) == 1, errors::InvalidArgument("split_index input to Split Op must be a " "scalar or a vector with 1 element")); - split_dim = xla::LiteralUtil::Get(literal_index, {0}); + split_dim = literal_index.Get({0}); } const int32 
num_split = num_outputs(); const TensorShape input_shape = ctx->InputShape(1); @@ -115,7 +115,7 @@ class SplitVOp : public XlaOpKernel { OP_REQUIRES(ctx, index_shape.dims() == 0, errors::InvalidArgument("split_dim input to Split Op must be a " "scalar")); - split_dim = xla::LiteralUtil::Get(literal_index, {}); + split_dim = literal_index.Get({}); xla::ComputationDataHandle input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); @@ -152,7 +152,7 @@ class SplitVOp : public XlaOpKernel { for (int i = 0; i < num_split; ++i) { int slice_size; - slice_size = xla::LiteralUtil::Get(split_size_literal, {i}); + slice_size = split_size_literal.Get({i}); if (slice_size == -1) { OP_REQUIRES( ctx, neg_one_dim == -1, diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 9eb689983105eff05555bbe454f97149eb8f14a2..6af4bd0496e0da926726e3f74376281f539e925a 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -63,17 +63,13 @@ class StridedSliceOp : public XlaOpKernel { &strides_tensor)); TensorShape dummy_processing_shape; - ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); - ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape( - &dummy_processing_shape); bool dummy = false; - OP_REQUIRES_OK( - ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, - ShapeReadWriteFromTensorShape(&input_shape), begin_mask_, - end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, - &dummy, &dummy, &begin, &end, &strides)); + OP_REQUIRES_OK(ctx, + ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, input_shape, + begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &dummy_processing_shape, &final_shape, + &dummy, &dummy, &dummy, &begin, &end, &strides)); gtl::InlinedVector 
dimensions_to_reverse; gtl::InlinedVector slice_begin, slice_end, slice_strides; @@ -146,14 +142,11 @@ class StridedSliceGradOp : public XlaOpKernel { &strides_tensor)); bool dummy = false; - ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); - ShapeReadWriteFromTensorShape wrapped_processing_shape(&processing_shape); OP_REQUIRES_OK( ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, - ShapeReadWriteFromTensorShape(&input_shape), begin_mask_, - end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &wrapped_processing_shape, &wrapped_final_shape, &dummy, + &begin_tensor, &end_tensor, strides_tensor, input_shape, + begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &processing_shape, &final_shape, &dummy, &dummy, &dummy, &begin, &end, &strides)); // Check to make sure dy is consistent with the original slice @@ -257,17 +250,13 @@ class StridedSliceAssignOp : public XlaOpKernel { const TensorShape rhs_shape = ctx->InputShape(4); TensorShape dummy_processing_shape; - ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); - ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape( - &dummy_processing_shape); bool dummy = false; - OP_REQUIRES_OK( - ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, - ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_, - end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, - &dummy, &dummy, &begin, &end, &strides)); + OP_REQUIRES_OK(ctx, + ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, lhs_shape, + begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &dummy_processing_shape, &final_shape, + &dummy, &dummy, &dummy, &begin, &end, &strides)); if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) { // DynamicUpdateSlice does not allow 0-element updates. 
We should probably diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index deee7dd44dbf80f83ded3f09819365f7b6c1c7bd..d720496f74e1d5d0b92e7dd658f9156e2baf2404 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -41,36 +41,42 @@ namespace { // Since the element shape is not always provided to the TensorArrayV3 operator, // we must support lazily initialization of the TensorArray at the time of the // first write. -// If a TensorArray `var` has not been initialized, constructs storage for the -// TensorArray with elements of `elem_shape`. For both initialized and +// If a TensorArray `resource` has not been initialized, constructs storage for +// the TensorArray with elements of `elem_shape`. For both initialized and // uninitialized TensorArrays, checks that the tensor has a type compatible with // 'dtype' and shape compatible with 'elem_shape'. 
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, - XlaVariable* var, DataType dtype, + XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (var->type != dtype) { + if (resource->kind != XlaResource::kTensorArray) { + return errors::InvalidArgument("Unexpected non-TensorArray resource"); + } + + if (resource->type != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(var->type), + "TensorArray dtype is ", DataTypeString(resource->type), " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(var->tensor_array_size >= 0) - << var->name << " size " << var->tensor_array_size; + TF_RET_CHECK(resource->tensor_array_size >= 0) + << resource->name << " size " << resource->tensor_array_size; TensorShape ta_shape; - ta_shape.AddDim(var->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size); ta_shape.AppendShape(elem_shape); - if (var->value.handle() == 0) { + if (resource->value.handle() == 0) { // TensorArray has not been initialized. - xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type); - var->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); + resource->value = builder->Broadcast(zero, ta_shape.dim_sizes()); } else { // Checks the elem_shape matches the TensorArray shape. 
- auto shape_or_status = builder->GetShape(var->value); + auto shape_or_status = builder->GetShape(resource->value); if (!shape_or_status.ok()) { return shape_or_status.status(); } - TensorShape shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + TensorShape shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); if (ta_shape != shape) { return errors::InvalidArgument( "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", @@ -80,6 +86,45 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, return Status::OK(); } +// Checks that the TensorArray 'resource' has been initialized, and has type +// 'dtype'. 'op_name' identifies the calling op in error messages. +Status CheckTensorArrayIsInitialized(const string& op_name, + const XlaResource* resource, + DataType dtype) { + if (resource->kind != XlaResource::kTensorArray) { + return errors::InvalidArgument( + "Unexpected non-TensorArray resource passed " + "to ", + op_name); + } + if (resource->value.handle() == 0) { + return errors::InvalidArgument("Uninitialized TensorArray passed to ", + op_name); + } + if (resource->type != dtype) { + return errors::InvalidArgument( + "TensorArray dtype is ", DataTypeString(resource->type), + " but op has dtype ", DataTypeString(dtype), "."); + } + + return Status::OK(); +} + +Status GetTensorArrayShape(const XlaResource* resource, + xla::ComputationBuilder* builder, + TensorShape* shape) { + auto shape_or_status = builder->GetShape(resource->value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); + if (shape->dims() < 1) { + return errors::InvalidArgument("TensorArray rank must be >= 1"); + } + return Status::OK(); +} + // Pads 'x' with 'count' zero indices. 'x' must have 1 element. 
xla::ComputationDataHandle PadIndexWithZeros( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, @@ -125,7 +170,6 @@ class TensorArrayOp : public XlaOpKernel { errors::InvalidArgument("TensorArray size must be >= 0")); xla::ComputationBuilder* b = ctx->builder(); - b->set_die_immediately_on_error(true); // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. @@ -141,12 +185,13 @@ class TensorArrayOp : public XlaOpKernel { } XlaContext& xc = XlaContext::Get(ctx); - XlaVariable* var; + XlaResource* var; string name = strings::StrCat("TensorArray: ", tensor_array_name_); - OP_REQUIRES_OK(ctx, - xc.CreateVariable(-1, std::move(name), dtype_, value, &var)); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + dtype_, value, &var)); var->tensor_array_size = size; - ctx->SetVariableOutput(0, var); + ctx->SetResourceOutput(0, var); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); } @@ -173,11 +218,12 @@ class TensorArrayWriteOp : public XlaOpKernel { // Initializes the TensorArray, if the element shape was not known at // construction time. 
- XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); xla::ComputationDataHandle value = ctx->Input(2); @@ -191,7 +237,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written)); + resource->value = written; ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -210,20 +256,17 @@ class TensorArrayReadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; - TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(ta_type), - " but Op requested dtype ", DataTypeString(dtype_), ".")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); + + xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
@@ -255,13 +298,15 @@ class TensorArrayGatherOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; + xla::ComputationBuilder* b = ctx->builder(); + + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument("TensorArray type mismatch")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES(ctx, indices_shape.dims() >= 1, @@ -269,10 +314,7 @@ class TensorArrayGatherOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); auto indices = ctx->Input(1); - xla::ComputationBuilder* b = ctx->builder(); - - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + xla::ComputationDataHandle ta = resource->value; // For each index in `indices`, add the corresponding slice to `slices`. 
std::vector slices(num_indices); @@ -320,11 +362,12 @@ class TensorArrayScatterOp : public XlaOpKernel { const TensorShape value_shape = ctx->InputShape(2); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); TensorShape elem_shape = value_shape; elem_shape.RemoveDim(0); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES(ctx, indices_shape.dims() >= 1, @@ -332,7 +375,7 @@ class TensorArrayScatterOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); const xla::ComputationDataHandle indices = ctx->Input(1); - xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle ta = resource->value; const xla::ComputationDataHandle value = ctx->Input(2); auto slice_dims = value_shape.dim_sizes(); @@ -357,7 +400,7 @@ class TensorArrayScatterOp : public XlaOpKernel { ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + resource->value = ta; ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -376,18 +419,17 @@ class TensorArrayConcatOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; - TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument("TensorArray type mismatch")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + 
CheckTensorArrayIsInitialized(name(), resource, dtype_)); + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); + + xla::ComputationDataHandle ta = resource->value; auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); @@ -438,19 +480,20 @@ class TensorArraySplitOp : public XlaOpKernel { elem_shape.set_dim(0, length); xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); - xla::ComputationDataHandle ta = var->value; + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); + xla::ComputationDataHandle ta = resource->value; TensorShape ta_shape; - ta_shape.AddDim(var->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size); ta_shape.AppendShape(elem_shape); - OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size, + OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size, errors::InvalidArgument( "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", var->tensor_array_size, ")")); + lengths.size(), " vs. ", resource->tensor_array_size, ")")); const xla::ComputationDataHandle value = ctx->Input(1); @@ -459,8 +502,7 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. 
", ta_shape.DebugString())); - ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -478,8 +520,8 @@ class TensorArraySizeOp : public XlaOpKernel { explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* var; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); size_tensor.scalar()() = static_cast(var->tensor_array_size); ctx->SetConstantOutput(0, size_tensor); @@ -500,31 +542,31 @@ class TensorArrayGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - DataType ta_type; + OP_REQUIRES_OK( + ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type)); TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); // Finds or looks up the corresponding gradient TensorArray, which stores // gradients computed during backpropagation. 
- XlaVariable*& gradient = var->tensor_array_gradient[source_]; + XlaResource*& gradient = resource->tensor_array_gradient[source_]; if (!gradient) { - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type); xla::ComputationDataHandle value = b->Broadcast(zero, ta_shape.dim_sizes()); XlaContext& xc = XlaContext::Get(ctx); - string name = strings::StrCat("TensorArrayGrad: ", var->name); - OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type, - value, &gradient)); - gradient->tensor_array_size = var->tensor_array_size; + string name = strings::StrCat("TensorArrayGrad: ", resource->name); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + resource->type, value, &gradient)); + gradient->tensor_array_size = resource->tensor_array_size; } - ctx->SetVariableOutput(0, gradient); + ctx->SetResourceOutput(0, gradient); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); } @@ -536,5 +578,19 @@ class TensorArrayGradOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp); +class TensorArrayCloseOp : public XlaOpKernel { + public: + explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // Do nothing; XLA handles resource management. 
+ } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 4cc2eb8f877a873593f0460346e3379e851e8e08..9ee6bd892504e683a191484fb09259619759f36d 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -68,7 +68,7 @@ class TileOp : public XlaOpKernel { bool all_multiples_are_one = true; bool one_dimension_is_broadcasted_without_multiple = true; for (int i = 0; i < input_dims; ++i) { - int multiple = xla::LiteralUtil::Get(literal, {i}); + int multiple = literal.Get({i}); OP_REQUIRES(ctx, multiple, errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", multiple)); diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index e9ac1ee91b8e86a7154f42b8c51dcbb5c8a32a83..42104b951ac06fa51da15902ac599174f9745b2b 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -364,112 +364,160 @@ class ResourceApplyRMSProp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp); -class ResourceApplyFtrl : public XlaOpKernel { - public: - explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); +void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, + bool has_l2_shrinkage) { + xla::ComputationBuilder* b = ctx->builder(); + + DataType var_type, accum_type, linear_type; + TensorShape var_shape, accum_shape, linear_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); + OP_REQUIRES_OK(ctx, + 
ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); + + OP_REQUIRES( + ctx, dtype == var_type && dtype == accum_type && dtype == linear_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyFtrlV2 must match: ", + DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ", + DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape), + errors::InvalidArgument( + "var and linear do not have the same shape", + var_shape.DebugString(), " ", linear_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + TensorShape lr_shape = ctx->InputShape(4); + TensorShape l1_shape = ctx->InputShape(5); + TensorShape l2_shape = ctx->InputShape(6); + TensorShape l2_shrinkage_shape; + TensorShape lr_power_shape; + if (has_l2_shrinkage) { + l2_shrinkage_shape = ctx->InputShape(7); + lr_power_shape = ctx->InputShape(8); + } else { + lr_power_shape = ctx->InputShape(7); } - void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument("var and grad do not have the same shape", + var_shape.DebugString(), " ", + grad_shape.DebugString())); - DataType var_type, accum_type, linear_type; - TensorShape var_shape, accum_shape, linear_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - OP_REQUIRES_OK( - ctx, ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", lr_shape.DebugString())); - OP_REQUIRES( - ctx, - dtype_ == var_type && dtype_ 
== accum_type && dtype_ == linear_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyFtrl must match: ", - DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(l1_shape), + errors::InvalidArgument("l1 is not a scalar: ", l1_shape.DebugString())); - OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), - errors::InvalidArgument( - "var and accum do not have the same shape", - var_shape.DebugString(), " ", accum_shape.DebugString())); - - OP_REQUIRES(ctx, var_shape.IsSameSize(linear_shape), - errors::InvalidArgument( - "var and linear do not have the same shape", - var_shape.DebugString(), " ", linear_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(l2_shape), + errors::InvalidArgument("l2 is not a scalar: ", l2_shape.DebugString())); - TensorShape grad_shape = ctx->InputShape(3); - TensorShape lr_shape = ctx->InputShape(4); - TensorShape l1_shape = ctx->InputShape(5); - TensorShape l2_shape = ctx->InputShape(6); - TensorShape lr_power_shape = ctx->InputShape(7); + if (has_l2_shrinkage) { + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shrinkage_shape), + errors::InvalidArgument("l2_shrinkage is not a scalar: ", + l2_shrinkage_shape.DebugString())); + } - OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), - errors::InvalidArgument( - "var and grad do not have the same shape", - var_shape.DebugString(), " ", grad_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape), + errors::InvalidArgument("lr_power is not a scalar: ", + lr_power_shape.DebugString())); + + xla::ComputationDataHandle var, accum, linear; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); + xla::ComputationDataHandle grad = ctx->Input(3); + xla::ComputationDataHandle lr 
= ctx->Input(4); + xla::ComputationDataHandle l1 = ctx->Input(5); + xla::ComputationDataHandle l2 = ctx->Input(6); + xla::ComputationDataHandle l2_shrinkage; + xla::ComputationDataHandle lr_power; + if (has_l2_shrinkage) { + l2_shrinkage = ctx->Input(7); + lr_power = ctx->Input(8); + } else { + lr_power = ctx->Input(7); + } - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), - errors::InvalidArgument("lr is not a scalar: ", - lr_shape.DebugString())); + // grad_to_use = grad + 2 * l2_shrinkage * var + // new_accum = accum + grad_to_use * grad_to_use + // linear += grad_to_use - + // (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var + // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 + // var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 + // accum = new_accum + + xla::ComputationDataHandle zero_broadcast = b->Broadcast( + XlaHelpers::FloatLiteral(b, dtype, 0.0), var_shape.dim_sizes()); + xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + xla::ComputationDataHandle grad_to_use; + if (has_l2_shrinkage) { + grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); + } else { + grad_to_use = grad; + } - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape), - errors::InvalidArgument("l1 is not a scalar: ", - l1_shape.DebugString())); + xla::ComputationDataHandle new_accum = + b->Add(accum, b->Pow(grad_to_use, two)); + xla::ComputationDataHandle new_accum_lr_pow = + b->Pow(new_accum, b->Neg(lr_power)); + xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); + linear = b->Add( + linear, + b->Sub(grad_to_use, + b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); + xla::ComputationDataHandle quadratic = + b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); + xla::ComputationDataHandle pre_shrink = + b->Div(b->Sub(b->Mul(l1, b->Sign(linear)), linear), quadratic); + var = b->Select(b->Gt(b->Abs(linear), l1), pre_shrink, zero_broadcast); + accum = new_accum; + + 
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype, accum)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype, linear)); +} - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape), - errors::InvalidArgument("l2 is not a scalar: ", - l2_shape.DebugString())); +class ResourceApplyFtrl : public XlaOpKernel { + public: + explicit ResourceApplyFtrl(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_power_shape), - errors::InvalidArgument("lr_power is not a scalar: ", - lr_power_shape.DebugString())); + void Compile(XlaOpKernelContext* ctx) override { + CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/false); + } - xla::ComputationDataHandle var, accum, linear; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); - xla::ComputationDataHandle grad = ctx->Input(3); - xla::ComputationDataHandle lr = ctx->Input(4); - xla::ComputationDataHandle l1 = ctx->Input(5); - xla::ComputationDataHandle l2 = ctx->Input(6); - xla::ComputationDataHandle lr_power = ctx->Input(7); - - // new_accum = accum + grad * grad - // linear += grad - (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var - // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 - // var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 - // accum = new_accum - - xla::ComputationDataHandle zero_broadcast = b->Broadcast( - XlaHelpers::FloatLiteral(b, dtype_, 0.0), var_shape.dim_sizes()); - xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl); - xla::ComputationDataHandle new_accum = b->Add(accum, b->Pow(grad, two)); - xla::ComputationDataHandle new_accum_lr_pow = - b->Pow(new_accum, b->Neg(lr_power)); - 
xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); - linear = b->Add( - linear, - b->Sub(grad, b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), - var))); - xla::ComputationDataHandle quadratic = - b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); - xla::ComputationDataHandle pre_shrink = - b->Div(b->Sub(b->Mul(l1, b->Sign(linear)), linear), quadratic); - var = b->Select(b->Gt(b->Abs(linear), l1), pre_shrink, zero_broadcast); - accum = new_accum; +class ResourceApplyFtrlV2 : public XlaOpKernel { + public: + explicit ResourceApplyFtrlV2(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, linear)); + void Compile(XlaOpKernelContext* ctx) override { + CompileFtrl(ctx, dtype_, /*has_l2_shrinkage=*/true); } private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl); +REGISTER_XLA_OP(Name("ResourceApplyFtrlV2"), ResourceApplyFtrlV2); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index abe4949f5dbc8034fa46828e3ff872cae7591d90..ab96d86ed2aa3299757d357aede86404ca8c8a21 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -44,6 +44,7 @@ namespace { // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); +XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. 
@@ -77,12 +78,19 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, b->LogicalAnd(b->Eq(fraction, half), is_odd)), b->Add(round_val, one), round_val); } -XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); +// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. +static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b, + DataType dtype, + const xla::ComputationDataHandle& x) { + auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); + return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); +} + +XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); XLAJIT_MAKE_UNARY(Rsqrt, b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); -XLAJIT_MAKE_UNARY(Sigmoid, - b->Map({x}, *ctx->GetOrCreateSigmoid(input_type(0)))); +XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); XLAJIT_MAKE_UNARY(Softplus, b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0))))); XLAJIT_MAKE_UNARY(Sqrt, diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..37c611d14792d85d88bb0fce1c76b5af1cfa1965 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -0,0 +1,269 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace { + +// Builds XlaCompiler argument descriptions `args` from `ctx`. +Status MakeXlaCompilerArgumentsFromInputs( + XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args, + bool* has_uninitialized_vars) { + VLOG(2) << "Num inputs " << ctx->num_inputs(); + args->resize(ctx->num_inputs()); + *has_uninitialized_vars = false; + for (int i = 0; i < ctx->num_inputs(); ++i) { + VLOG(2) << " Input " << i + << " type: " << DataTypeString(ctx->input_type(i)) + << " shape: " << ctx->InputShape(i).DebugString(); + XlaCompiler::Argument& arg = (*args)[i]; + DataType type = ctx->input_type(i); + // When reading a resource input, use the type and shape of the resource's + // current value. 
+ if (type == DT_RESOURCE) { + XlaResource* resource; + TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource)); + + arg.initialized = resource->value.handle() > 0; + switch (resource->kind) { + case XlaResource::kVariable: + arg.kind = XlaCompiler::Argument::kVariable; + break; + case XlaResource::kTensorArray: + arg.kind = XlaCompiler::Argument::kTensorArray; + break; + case XlaResource::kInvalid: + CHECK(false); + } + arg.type = resource->type; + if (arg.initialized) { + auto shape = ctx->builder()->GetShape(resource->value); + TF_RETURN_IF_ERROR(shape.status()); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape.ValueOrDie(), &arg.shape)); + } else { + *has_uninitialized_vars = true; + } + arg.tensor_array_size = resource->tensor_array_size; + arg.name = resource->name; + // TODO(phawkins): propagate TensorArray gradients into loops. + VLOG(2) << " resource " << resource->name + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString() + << " initialized: " << arg.initialized; + + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = ctx->input_type(i); + arg.shape = ctx->InputShape(i); + } + } + return Status::OK(); +} + +} // anonymous namespace + +XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr)); + cond_name_attr_ = *name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); + body_name_attr_ = *name_attr; +} + +void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { + VLOG(1) << "WhileOp::Compile"; + + std::vector<XlaCompiler::Argument> arguments; + bool has_uninitialized_vars; + OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs( + ctx, &arguments, &has_uninitialized_vars)); + + const bool use_tuple_arg = (arguments.size() != 1); + + xla::ComputationBuilder* builder = ctx->builder(); + XlaCompiler* compiler = ctx->compiler(); + + VLOG(1) << "Compiling body"; + + // All resources that are inputs to the loop's body must 
also be + // present as loop body outputs; the signature of the loop's input and + // output must match. We ensure this by asking the compiler to include the + // current values of all resources, even if they haven't been updated by the + // computation. We must also ask the compiler to keep compile-time constant + // outputs as part of the generated computation, for the same reason. + // TODO(phawkins): consider adding loop-invariant inputs to XLA's While() + // operator. + XlaCompiler::CompileOptions body_options; + body_options.use_tuple_arg = use_tuple_arg; + body_options.return_updated_values_for_all_resources = true; + body_options.resolve_compile_time_constants = false; + XlaCompiler::CompilationResult body; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, + arguments, &body)); + + // We must use a static shape for parameters to an XLA compilation. However, + // we may not know the shape of a TensorArray if it is first written inside + // the loop. Ideally we would require the user to provide a static shape, + // but this is not always easy. + // So if uninitialized resources are used by the loop body, we compile the + // body function twice: + // 1) once with uninitialized resource inputs. We discard the computation + // but we assume resource shapes reach a fixpoint after one iteration. + // So we can use the output shapes of the resource as the "true" shapes. + // 2) again with the "correct" input shapes determined by (1). + if (has_uninitialized_vars) { + // Initializes any uninitialized resource with zero values of the + // shape determined by the first compilation. 
+ for (int i = 0; i < body.resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + XlaCompiler::Argument& arg = arguments[update.input_index]; + if (!arg.initialized) { + arg.initialized = true; + arg.shape = update.shape; + + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index, &resource)); + + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, arg.type); + resource->value = builder->Broadcast(zero, update.shape.dim_sizes()); + } + } + // Recompile the body with the "correct" shapes. + body = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, + arguments, &body)); + } + + VLOG(1) << "Compiling condition"; + + XlaCompiler::CompileOptions cond_options; + cond_options.use_tuple_arg = use_tuple_arg; + cond_options.resolve_compile_time_constants = false; + XlaCompiler::CompilationResult cond; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, + arguments, &cond)); + + xla::Shape body_input_shape, cond_input_shape; + if (use_tuple_arg) { + body_input_shape = xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes); + cond_input_shape = xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes); + } else { + CHECK(!body.xla_input_shapes.empty()); + body_input_shape = body.xla_input_shapes[0]; + CHECK(!cond.xla_input_shapes.empty()); + cond_input_shape = cond.xla_input_shapes[0]; + } + + VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) + << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); + VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape) + << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape); + + OP_REQUIRES(ctx, + xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape), + errors::InvalidArgument( + "Input shapes of loop body and condition do not match: ", + xla::ShapeUtil::HumanString(body_input_shape), " vs. 
", + xla::ShapeUtil::HumanString(cond_input_shape))); + OP_REQUIRES( + ctx, xla::ShapeUtil::Compatible(body_input_shape, body.xla_output_shape), + errors::InvalidArgument( + "Input and output shapes of loop body do not match: ", + xla::ShapeUtil::HumanString(body_input_shape), " vs. ", + xla::ShapeUtil::HumanString(body.xla_output_shape))); + + xla::ComputationDataHandle data; + + int num_inputs = body.input_mapping.size(); + + std::vector<xla::ComputationDataHandle> inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = body.input_mapping[i]; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + inputs[i] = resource->value; + } else { + inputs[i] = ctx->Input(input_num); + } + } + + xla::ComputationDataHandle init; + if (use_tuple_arg) { + init = builder->Tuple(inputs); + } else { + init = inputs[0]; + } + + VLOG(1) << "Building while loop"; + + xla::ComputationDataHandle while_result = + builder->While(*cond.computation, *body.computation, init); + + auto get_loop_output = [&](int i) { + if (use_tuple_arg) { + return builder->GetTupleElement(while_result, i); + } else { + return while_result; + } + }; + + // Sets non-variable outputs. + for (int i = 0; i < ctx->num_outputs(); ++i) { + if (ctx->input_type(body.input_mapping[i]) != DT_RESOURCE) { + ctx->SetOutput(body.input_mapping[i], get_loop_output(i)); + } + } + + // Updates the values of any resource variables modified by the loop. 
+ for (int i = 0; i < body.resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); + if (update.modified) { + int pos = body.outputs.size() + i; + resource->value = get_loop_output(pos); + } + VLOG(2) << "Loop-carried variable: pos: " << update.input_index + << " name: " << resource->name << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + // Copies the identity of the resource variable from input to output + // unchanged, even if the variable was not modified. + ctx->op_kernel_context()->set_output( + update.input_index, + ctx->op_kernel_context()->input(update.input_index)); + } + + VLOG(1) << "Done building while loop"; +} + +REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h new file mode 100644 index 0000000000000000000000000000000000000000..67edebabf9f643a919d0f06c228e2d224a49a2af --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional iteration primitive. +// +// The inputs and outputs of the loop body must agree on the number, types, and +// shapes of the Tensors carried around the loop body. +// +// Computations in while loops may read from and write to resource variables. +// Resource variables may be passed as arguments to a function's body and +// condition functions. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_resources` +// option we ensure all variables that appear in the input also appear at the +// end of the body's output. This ensures the loop body's input and output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +// +// For example, suppose we have a loop body with arguments: +// DT_INT32, DT_RESOURCE (pointing to a DT_BOOL var), DT_FLOAT +// and return values +// DT_INT32, DT_FLOAT +// It is an error for the body to return DT_RESOURCE values. +// +// The body will be lowered into an XLA computation that takes and returns a +// tuple with XLA type (I32, F32, PRED). Note the resource variable appears at +// the end of both the loop body's input and output argument lists. 
+class XlaWhileOp : public XlaOpKernel { + public: + explicit XlaWhileOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + NameAttrList cond_name_attr_; + NameAttrList body_name_attr_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 1f2bc01cf4a48b37de585c55b781c239ee4b8f2a..576cd9bf9abb43e29d9eb8f706e0f42ac2d038e9 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -27,13 +27,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { TF_RETURN_IF_ERROR(TensorShapeToXLAShape( host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape())); - xla::LiteralUtil::Reserve(host_tensor.NumElements(), literal); + literal->Reserve(host_tensor.NumElements()); // memcpy over the payload ... // TODO(phawkins): handle string types. 
size_t total_bytes = host_tensor.TotalBytes(); if (total_bytes > 0) { - void* dst_ptr = xla::LiteralUtil::MutableInternalData(literal); + void* dst_ptr = literal->MutableInternalData(); const void* src_ptr = DMAHelper::base(&host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } @@ -51,11 +51,12 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, " to tensor of type ", DataTypeString(target_type)); } - TensorShape shape = XLAShapeToTensorShape(literal.shape()); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); size_t total_bytes = host_tensor->TotalBytes(); if (total_bytes > 0) { - const void* src_ptr = xla::LiteralUtil::InternalData(literal); + const void* src_ptr = literal.InternalData(); void* dst_ptr = DMAHelper::base(host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 56993bc58534d1225f9177719804a69f561b3a06..f3d6787daaa1165b28ce63dfd501533fa0963edd 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -27,7 +27,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { { std::vector int64_values = {1, 2, 3}; std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); + xla::Literal::CreateR1(gtl::ArraySlice(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -48,7 +48,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector int32_values = {10, 11}; std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); + xla::Literal::CreateR1(gtl::ArraySlice(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff 
--git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a2bd06861d5f383e3497a386b42a2e5a4035f1ea --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -0,0 +1,38 @@ +package( + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + +cc_library( + name = "functional_ops", + srcs = ["functional_ops.cc"], + deps = [ + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_gen_op_wrapper_py( + name = "gen_functional_ops", + out = "gen_functional_ops.py", + deps = [ + ":functional_ops", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..38bcaa32278c4acf212881b10d66bb67b807a21c --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/functional_ops.cc @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("XlaWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index f5ecb51a5b77e36e606ed1c48b8e2dbe76de0074..9d1992205b02665b99b1bd15b7b65a1fb8c35a51 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -24,12 +24,18 @@ limitations under the License. namespace tensorflow { // Convert an XLA Shape into the equivalent TensorFlow shape. 
-TensorShape XLAShapeToTensorShape(const xla::Shape& shape) { - TensorShape tensor_shape; +Status XLAShapeToTensorShape(const xla::Shape& shape, + TensorShape* tensor_shape) { + if (xla::ShapeUtil::IsTuple(shape)) { + return errors::InvalidArgument("XLA shape ", + xla::ShapeUtil::HumanString(shape), + " cannot be converted to a TensorShape"); + } + *tensor_shape = TensorShape(); for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) { - tensor_shape.AddDim(shape.dimensions(i)); + tensor_shape->AddDim(shape.dimensions(i)); } - return tensor_shape; + return Status::OK(); } // Convert a TensorShape into the equivalent XLA Shape proto. diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 516dd636a970f78fda363a0b13961b8244dc2cd9..58240b9c965a194b9380ac7cd477ce7344e5ebe3 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -24,8 +24,10 @@ limitations under the License. namespace tensorflow { -// Convert an XLA Shape into the equivalent TensorFlow shape. -TensorShape XLAShapeToTensorShape(const xla::Shape& shape); +// Convert an XLA Shape into the equivalent TensorFlow shape. May fail since +// not all XLA shapes can be represented as TensorShapes. +Status XLAShapeToTensorShape(const xla::Shape& shape, + TensorShape* tensor_shape); // Convert a TensorShape into the equivalent XLA Shape proto. Unlike Tensorflow, // XLA shapes include the type. Not all `dtype` values can be represented by diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c6c9a91b6d2fb47f6dee1c347e9b852f1eea3ec --- /dev/null +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/test_util.h" + +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { + +Status InstantiateFunctionForTest(const string& name, + const FunctionLibraryDefinition& library, + InstantiationResultForTest* result) { + const FunctionDef* fdef = library.Find(name); + TF_RET_CHECK(fdef != nullptr); + + auto get_func_sig = [&library](const string& op, const OpDef** sig) { + return library.LookUpOpDef(op, sig); + }; + InstantiationResult inst; + TF_RETURN_IF_ERROR( + InstantiateFunction(*fdef, AttrSlice(), get_func_sig, &inst)); + result->arg_types = inst.arg_types; + result->ret_types = inst.ret_types; + for (NodeDef& n : inst.nodes) { + *result->gdef.add_node() = std::move(n); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e6e4ae92ed23f3fca0f59b131dc73152e0947b72 --- /dev/null +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for tests. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Same as InstantiationResult, but has a GraphDef instead of just nodes. +struct InstantiationResultForTest { + DataTypeVector arg_types; + DataTypeVector ret_types; + GraphDef gdef; +}; + +// Instantiates a function, producing a GraphDef to compare against the +// expected graph. +Status InstantiateFunctionForTest(const string& name, + const FunctionLibraryDefinition& library, + InstantiationResultForTest* result); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 75630bee3961243b2389274f0f98200ee3a0a7eb..e4f43f1950d3c8ced57e80994845dbfd498df4ed 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -64,26 +64,35 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -struct XlaVariable { - // If this variable is visible externally, what was its argument number? +// Represents a resource, such as a Variable or TensorArray. 
+struct XlaResource { + enum Kind { + kInvalid, + kVariable, + kTensorArray, + }; + + Kind kind = kInvalid; + + // If this resource is visible externally, what was its argument number? int arg_num = -1; - // A descriptive name for the variable, used in error messages. + // A descriptive name for the resource, used in error messages. string name; - // Current type and value of the variable. Uninitialized variables are + // Current type and value of the resource. Uninitialized resources are // represented by a default (zero) handle and type DT_INVALID. - // While the type of a variable is notionally fixed during execution, when - // a variable is first initialized we do not yet know its type, so we keep + // While the type of a resource is notionally fixed during execution, when + // a resource is first initialized we do not yet know its type, so we keep // track of its type dynamically. DataType type = DT_INVALID; xla::ComputationDataHandle value; - // Value of the variable at computation entry. Used to detect which + // Value of the resource at computation entry. Used to detect which // variables have new values that need to be written back. xla::ComputationDataHandle initial_value; - // We treat TensorArrays as a Variable with some extra metadata. + // TensorArray-specific fields // 'tensor_array_size' stores the expected size of the TensorArray. We need // to store this since sometimes TensorArrays must be initialized lazily since @@ -91,10 +100,10 @@ struct XlaVariable { int64 tensor_array_size = -1; // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes - // to an XlaVariable containing the gradient TensorArrays. We store a pointer + // to an XlaResource containing the gradient TensorArrays. We store a pointer // here since there should only be one gradient TensorArray per 'source' // string, irrespective of the number of calls to TensorArrayGrad. 
- std::unordered_map tensor_array_gradient; + std::unordered_map tensor_array_gradient; }; // A XlaExpression wraps an XLA computation. Each Tensor on an @@ -115,8 +124,8 @@ class XlaExpression { bool has_constant_value() const { return has_constant_value_; } const Tensor& constant_value() const { return constant_value_; } - void set_variable(XlaVariable* variable) { variable_ = variable; } - XlaVariable* variable() const { return variable_; } + void set_resource(XlaResource* resource) { resource_ = resource; } + XlaResource* resource() const { return resource_; } private: // The XLA handle of the expression's computation. @@ -128,7 +137,7 @@ class XlaExpression { bool has_constant_value_ = false; Tensor constant_value_; - XlaVariable* variable_ = nullptr; // Not owned. + XlaResource* resource_ = nullptr; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 580ce3d802e71ef99903321fff2bc7374d0a9470..c1af731986ccccb2eff5f31aa37646a32099ff3b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -85,9 +86,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) (*options_.populate_resource_manager)(device_->resource_manager()); } + flib_def_.reset(new FunctionLibraryDefinition(*options.flib_def)); flib_runtime_.reset(NewFunctionLibraryRuntime( &device_mgr_, Env::Default(), device_, options.graph_def_version, - options.flib_def, OptimizerOptions(), + flib_def_.get(), OptimizerOptions(), nullptr /* custom_kernel_creator */)); } @@ -129,9 +131,11 @@ Status XlaCompiler::CompileFunction( std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); - if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_function_input_", function_id), *graph); + if (VLOG_IS_ON(2)) { + VLOG(2) << "XlaCompiler::CompileFunction: " + << dump_graph::DumpGraphToFile( + strings::StrCat("xla_compile_function_", function_id), + *graph); } // Optimize the graph before running the compiler. 
@@ -143,12 +147,6 @@ Status XlaCompiler::CompileFunction( optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(), /*device=*/nullptr, &graph); - if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_function_optimized_", function_id), - *graph); - } - VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); @@ -249,35 +247,36 @@ Status BuildArguments(const std::vector& args, std::vector* input_shapes) { context_args->resize(args.size()); - // Argument numbers of arguments and variables that are to be passed to the + // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. - std::vector parameters, variables; + std::vector parameters, resources; parameters.reserve(args.size()); - variables.reserve(args.size()); + resources.reserve(args.size()); for (std::vector::size_type i = 0; i < args.size(); ++i) { XlaContext::Argument& context_arg = (*context_args)[i]; + context_arg.kind = args[i].kind; context_arg.name = args[i].name; context_arg.value.constant_value = args[i].constant_value; context_arg.value.type = args[i].type; switch (args[i].kind) { case XlaCompiler::Argument::kVariable: - variables.push_back(i); - context_arg.is_variable = true; - context_arg.value.is_constant = false; + case XlaCompiler::Argument::kTensorArray: + context_arg.is_resource = true; + if (args[i].initialized) { + resources.push_back(i); + context_arg.value.is_constant = false; + } else { + context_arg.value.is_constant = true; + } context_arg.tensor_array_size = args[i].tensor_array_size; break; case XlaCompiler::Argument::kParameter: parameters.push_back(i); context_arg.value.is_constant = false; break; - case XlaCompiler::Argument::kUninitializedVariable: - context_arg.is_variable = true; - context_arg.value.is_constant = true; - context_arg.tensor_array_size = args[i].tensor_array_size; - break; 
case XlaCompiler::Argument::kConstant: context_arg.value.is_constant = true; break; @@ -288,7 +287,7 @@ Status BuildArguments(const std::vector& args, // Append parameters containing variable values after the other runtime // parameters. - parameters.insert(parameters.end(), variables.begin(), variables.end()); + parameters.insert(parameters.end(), resources.begin(), resources.end()); if (parameters.empty()) { return Status::OK(); } @@ -329,22 +328,22 @@ Status BuildArguments(const std::vector& args, // variable states, generated by the symbolic evaluation. // If `has_side_effects` is true, the computation has side effects and should be // built even if it has no outputs. -// If `return_updated_values_for_all_variables` is true, all variables will be -// included in `variable_updates`, regardless of whether their value changed. +// If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*variable_updates` to a description of variables whose values are +// Sets `*resource_updates` to a description of resources whose values are // written by the computation; the variable writes are the last -// `variable_updates.size()` return values from the computation. Each entry in -// `variable_updates` is a (input_index, type) pair, where `input_index` is the +// `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a (input_index, type) pair, where `input_index` is the // index of a resource variable argument to the computation, and `type` is the // type of the final output. 
Status BuildComputation( const std::vector& retvals, - const std::vector>& variables, - bool has_side_effects, bool return_updated_values_for_all_variables, + const std::vector>& resources, + bool has_side_effects, bool return_updated_values_for_all_resources, xla::ComputationBuilder* builder, xla::Computation* computation, int* num_nonconst_outputs, - std::vector* variable_updates) { + std::vector* resource_updates) { std::vector elems; elems.reserve(retvals.size()); for (const XlaContext::HandleOrConstant& retval : retvals) { @@ -354,24 +353,24 @@ Status BuildComputation( } *num_nonconst_outputs = elems.size(); - // Add return values for variables whose values have changed. - std::vector arg_vars; - arg_vars.reserve(variables.size()); - for (const auto& var : variables) { + // Add return values for resources whose values have changed. + std::vector arg_vars; + arg_vars.reserve(resources.size()); + for (const auto& var : resources) { if (var->arg_num >= 0) { arg_vars.push_back(var.get()); } } std::sort(arg_vars.begin(), arg_vars.end(), - [](const XlaVariable* a, const XlaVariable* b) { + [](const XlaResource* a, const XlaResource* b) { return a->arg_num < b->arg_num; }); - for (const XlaVariable* var : arg_vars) { + for (const XlaResource* var : arg_vars) { bool modified = var->value.handle() != var->initial_value.handle(); - if (return_updated_values_for_all_variables || modified) { - variable_updates->emplace_back(); - XlaCompiler::VariableUpdate& update = variable_updates->back(); + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); update.input_index = var->arg_num; update.type = var->type; update.modified = modified; @@ -410,13 +409,23 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; + if (VLOG_IS_ON(2)) { + VLOG(2) << 
"XlaCompiler::CompileGraph: " + << dump_graph::DumpGraphToFile( + strings::StrCat("xla_compile_graph_", name), *graph); + } + // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); + // Converts Tensorflow's graph control-flow constructs into functional + // control-flow that can be compiled into XLA code. + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(graph.get(), flib_def_.get())); + xla::ComputationBuilder builder(client(), name); XlaContext* context = new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options_.resolve_compile_time_constants); + options.resolve_compile_time_constants); core::ScopedUnref context_unref(context); result->tuple_arg = options.use_tuple_arg; @@ -433,10 +442,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; result->computation = std::make_shared(); TF_RETURN_IF_ERROR(BuildComputation( - context->retvals(), context->variables(), context->has_side_effects(), - options.return_updated_values_for_all_variables, &builder, + context->retvals(), context->resources(), context->has_side_effects(), + options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_nonconst_outputs, - &result->variable_updates)); + &result->resource_updates)); result->requires_runtime_context = context->has_context_parameter(); @@ -501,26 +510,29 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, OutputDescription& output = result->outputs[i]; output.is_constant = false; if (num_computation_outputs > 1) { - output.shape = - XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( - result->xla_output_shape, computation_output)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape( + xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, + computation_output), + &output.shape)); } else { - output.shape = XLAShapeToTensorShape(result->xla_output_shape); + TF_RETURN_IF_ERROR( + 
XLAShapeToTensorShape(result->xla_output_shape, &output.shape)); } ++computation_output; } } - for (std::vector::size_type i = 0; - i < result->variable_updates.size(); ++i) { + for (std::vector::size_type i = 0; + i < result->resource_updates.size(); ++i) { if (num_computation_outputs > 1) { - result->variable_updates[i].shape = - XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( - result->xla_output_shape, computation_output)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape( + xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, + computation_output), + &result->resource_updates[i].shape)); } else { CHECK_EQ(0, computation_output); - result->variable_updates[i].shape = - XLAShapeToTensorShape(result->xla_output_shape); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape( + result->xla_output_shape, &result->resource_updates[i].shape)); } ++computation_output; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 131430553252e2b62315c6388a53058bdf20eb7f..197e45617c4a11f0b1e60ec169e1acba7a2f2651 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -85,14 +85,14 @@ class XlaCompiler { // Argument is a compile-time constant. No associated runtime parameter. kConstant, - // Argument is a variable that has not been initialized yet. No associated - // runtime parameter. - kUninitializedVariable, - - // Argument is a variable that already has a value set. Expects a runtime - // parameter containing the current value. + // Argument is a variable resource. Has an associated runtime parameter + // iff `initialized` is true. kVariable, + // Argument is a TensorArray resource. Has an associated runtime parameter + // iff `initialized` is true. + kTensorArray, + // Argument is a run-time parameter. kParameter, }; @@ -114,8 +114,11 @@ class XlaCompiler { // The name of this argument, used for debugging. 
string name; - // For a kVariable or kUninitializedVariable corresponding to a TensorArray, - // what is the tensor array's declared size? + // For a kVariable or kTensorArray, has this resource been initialized? + bool initialized = false; + + // For a kTensorArray, what is the array's declared size? (Used for lazy + // initialization.) int64 tensor_array_size = -1; bool operator==(const Argument& other) const; @@ -133,7 +136,7 @@ class XlaCompiler { }; // Describes a variable write side effect of the computation. - struct VariableUpdate { + struct ResourceUpdate { // Index of the input that contains the variable resource to write to. int input_index; @@ -142,14 +145,14 @@ class XlaCompiler { TensorShape shape; // Was the value of the variable modified by the computation? - // (Always true, unless `return_updated_values_for_all_variables` is true.) + // (Always true, unless `return_updated_values_for_all_resources` is true.) bool modified; }; struct CompilationResult { // Vector that maps from the parameters of the XLA computation to their // original argument positions. To handle compile-time constant inputs and - // variables, the parameters to the XLA computation may be a subset of the + // resources, the parameters to the XLA computation may be a subset of the // original arguments, and are not necessarily in the same order.) std::vector input_mapping; @@ -172,10 +175,10 @@ class XlaCompiler { // containing both constant and non-constant results. std::vector outputs; - // Variables whose values were updated by the computation, ordered - // by return value position. Variable updates follow the non-constant + // Resources whose values were updated by the computation, ordered + // by return value position. Resource updates follow the non-constant // results in the outputs of XLA computation. - std::vector variable_updates; + std::vector resource_updates; // The XLA computation built from the tensorflow subgraph. 
May be null // if the output consists solely of compile-time constants. @@ -206,12 +209,6 @@ class XlaCompiler { // stored in device memory. bool local_executable_has_hybrid_result = false; - // If 'resolve_compile_time_constants' is true, then outputs of a - // computation that are known to be compile-time constants will be returned - // as Tensors at compile-time, rather than as run-time outputs of the - // computation. - bool resolve_compile_time_constants = true; - // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects @@ -229,12 +226,18 @@ class XlaCompiler { // arguments; if false, each argument gets its own parameter. bool use_tuple_arg = false; - // If 'return_updated_values_for_all_variables' is true, then updated - // values of all resource variables arguments will be included in the - // 'variable_updates' of the computation, even if the variable was not + // If 'return_updated_values_for_all_resources' is true, then updated + // values of all resource resources arguments will be included in the + // 'resource_updates' of the computation, even if the resource was not // modified by the computation. Used when compiling loop bodies to ensure // the input and output signatures match. - bool return_updated_values_for_all_variables = false; + bool return_updated_values_for_all_resources = false; + + // If 'resolve_compile_time_constants' is true, then outputs of a + // computation that are known to be compile-time constants will be returned + // as Tensors at compile-time, rather than as run-time outputs of the + // computation. + bool resolve_compile_time_constants = true; }; // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. 
@@ -294,6 +297,7 @@ class XlaCompiler { XlaCompilationDevice* device_; // Owned by device_mgr_ DeviceMgr device_mgr_; + std::unique_ptr flib_def_; std::unique_ptr flib_runtime_; struct SignatureHash { diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 58d74057d101cdef89fca24ec6c0858291d825fa..97cd951000bc2eb23e966b29cfc74e93b7877b3e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -163,9 +163,9 @@ TEST_F(XlaCompilerTest, Simple) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -179,7 +179,7 @@ TEST_F(XlaCompilerTest, Simple) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected_literal = - xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal::CreateR1({4, 143}); xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } @@ -203,19 +203,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { args[0].type = DT_INT32; args[0].shape = TensorShape({2}); + XlaCompiler::Options options = DefaultOptions(); + XlaCompiler compiler(options); { // Compiles the graph, with resolve_compile_time_constants enabled. 
- XlaCompiler::Options options = DefaultOptions(); - options.resolve_compile_time_constants = true; - XlaCompiler compiler(options); std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = true; XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), - "constants", std::move(graph_copy), args, - &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph_copy), args, &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_TRUE(result.outputs[0].is_constant); @@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -236,23 +236,20 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected_literal = - xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal::CreateR1({-7, -42}); xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } { // Compiles the graph, with resolve_compile_time_constants disabled. 
- XlaCompiler::Options options = DefaultOptions(); - options.resolve_compile_time_constants = false; - XlaCompiler compiler(options); - std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = false; XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), - "constants", std::move(graph_copy), args, - &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph_copy), args, &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_FALSE(result.outputs[0].is_constant); @@ -260,7 +257,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -270,12 +267,11 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR0(7); + std::unique_ptr expected0 = xla::Literal::CreateR0(7); std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 4440b530696db6125e0af0606be49e2d834dbd9f..d4d493b456f668ecfbdd0164c573b9ae2aa810e9 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -129,16 +129,18 @@ void XlaContext::AddSideEffects() { xla::ComputationBuilder* XlaContext::builder() { 
return builder_; } -Status XlaContext::CreateVariable(int arg_num, string name, DataType type, +Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, + string name, DataType type, const xla::ComputationDataHandle& handle, - XlaVariable** variable) { - variables_.emplace_back(new XlaVariable); - *variable = variables_.back().get(); - XlaVariable& var = **variable; - var.arg_num = arg_num; - var.name = std::move(name); - var.type = type; - var.initial_value = var.value = handle; + XlaResource** resource) { + resources_.emplace_back(new XlaResource); + *resource = resources_.back().get(); + XlaResource& r = **resource; + r.kind = kind; + r.arg_num = arg_num; + r.name = std::move(name); + r.type = type; + r.initial_value = r.value = handle; return Status::OK(); } @@ -170,27 +172,6 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) { - return LookupOrCreate(type, &sigmoid_func_, [this, type] { - const string type_string = DataTypeString(type); - VLOG(1) << "Building Sigmoid() for " << type_string; - xla::ComputationBuilder b(builder()->client(), - "sigmoid<" + type_string + ">"); - xla::PrimitiveType xla_type; - TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - // Clamp the inputs to the range [-18, 18] since anything outside - // this range is 0.0f or 1.0f in single-precision. We must clamp the range - // of x to avoid incorrect outputs due to fast-math optimizations for large - // negative x. 
- x = b.Clamp(XlaHelpers::IntegerLiteral(&b, type, -18), x, - XlaHelpers::IntegerLiteral(&b, type, 18)); - auto one = XlaHelpers::One(&b, type); - b.Div(one, b.Add(b.Exp(b.Neg(x)), one)); - return b.Build().ConsumeValueOrDie(); - }); -} - const xla::Computation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, const std::function& create) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 3978baaf637b4948510eafe37de94a383a87ddc3..544921b9e38fb52e70b9f67ba10f7c79dc53c657 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -52,11 +52,13 @@ class XlaContext : public ResourceBase { }; struct Argument { - // Descriptive name for the variable, for use in error messages. + XlaCompiler::Argument::Kind kind; + + // Descriptive name for the resource, for use in error messages. string name; - // Is this a variable? - bool is_variable = false; + // Is this a resource? + bool is_resource = false; HandleOrConstant value; @@ -106,15 +108,15 @@ class XlaContext : public ResourceBase { bool has_side_effects() const { return has_side_effects_; } - // Creates a variable with variable `variable_id` and initial type `type` and + // Creates a resource with resource `kind` and initial type `type` and // value `handle`. `name` is a descriptive name for use in error messages. - // Fails if the variable already exists. - Status CreateVariable(int arg_num, string name, DataType type, - const xla::ComputationDataHandle& handle, - XlaVariable** variable); + // Fails if the resource already exists. + Status CreateResource(XlaResource::Kind kind, int arg_num, string name, + DataType type, const xla::ComputationDataHandle& handle, + XlaResource** resource); - const std::vector>& variables() { - return variables_; + const std::vector>& resources() { + return resources_; } // Get an XLA lambda to compute Max. 
This is cached in the @@ -127,11 +129,6 @@ class XlaContext : public ResourceBase { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); - // Get an XLA lambda to compute Sigmoid. This is cached in the - // XlaContext since it may be used by multiple Ops. There is a - // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateSigmoid(const DataType type); - // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -166,8 +163,8 @@ class XlaContext : public ResourceBase { // Does the computation have side effects, i.e., Send() calls? bool has_side_effects_ = false; - // Holds ownership of variables. The variables are not ordered. - std::vector> variables_; + // Holds ownership of resources. The resources are not ordered. + std::vector> resources_; // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f060f8f2f178b2bc56caf7a3df9df32c8a407473..2366c02dd2b0f22d3cbee929f31bdb0185bfabbc 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -30,28 +30,28 @@ xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::MinValue(type)); + return b->ConstantLiteral(xla::Literal::MinValue(type)); } xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::MaxValue(type)); + return b->ConstantLiteral(xla::Literal::MaxValue(type)); } xla::ComputationDataHandle 
XlaHelpers::Zero(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::Zero(type)); + return b->ConstantLiteral(xla::Literal::Zero(type)); } xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::One(type)); + return b->ConstantLiteral(xla::Literal::One(type)); } xla::ComputationDataHandle XlaHelpers::IntegerLiteral( @@ -61,28 +61,28 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { case xla::U8: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::U32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::U64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::S8: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::S32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::S64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::F32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::F64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -91,7 +91,7 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::F16: literal = - *xla::LiteralUtil::CreateR0(static_cast(value)); + *xla::Literal::CreateR0(static_cast(value)); break; case 
xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 3272b1efa153c0ecab720583277175b81fe59509..c5a68e05d9e1dfa3ed1c648e95d3690fadef8b51 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -39,7 +39,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); CHECK(expression->handle().handle() != 0 || - expression->variable() != nullptr); + expression->resource() != nullptr); VLOG(1) << "Fetched T" << expression->handle().handle(); return expression; } @@ -144,9 +144,9 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S32) { - *out = xla::LiteralUtil::Get(literal, {}); + *out = literal.Get({}); } else if (literal.shape().element_type() == xla::S64) { - *out = xla::LiteralUtil::Get(literal, {}); + *out = literal.Get({}); } else { return errors::InvalidArgument("value must be either int32 or int64"); } @@ -168,11 +168,11 @@ static Status LiteralToInt64Vector(const xla::Literal& literal, int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); if (literal.shape().element_type() == xla::S32) { for (int64 i = 0; i < size; ++i) { - out->push_back(xla::LiteralUtil::Get(literal, {i})); + out->push_back(literal.Get({i})); } } else if (literal.shape().element_type() == xla::S64) { for (int64 i = 0; i < size; ++i) { - out->push_back(xla::LiteralUtil::Get(literal, {i})); + out->push_back(literal.Get({i})); } } else { return errors::InvalidArgument("value must be either int32 or int64"); @@ -252,8 +252,9 @@ Status XlaOpKernelContext::ReadVariableInput( int index, xla::ComputationDataHandle* value) { const Tensor& tensor = context_->input(index); const XlaExpression* 
expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (variable->value.handle() == 0) { return errors::InvalidArgument("Read of uninitialized variable ", variable->name); @@ -262,22 +263,13 @@ Status XlaOpKernelContext::ReadVariableInput( return Status::OK(); } -string XlaOpKernelContext::VariableDebugString(int index) { - const Tensor& tensor = context_->input(index); - const XlaExpression* expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); - if (!variable) { - return ""; - } - return variable->name; -} - Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (variable->value.handle() == 0) { return errors::InvalidArgument("Read of uninitialized variable ", variable->name); @@ -287,7 +279,8 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, if (!shape_or_status.ok()) { return shape_or_status.status(); } - *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); return Status::OK(); } @@ -304,10 +297,11 @@ void XlaOpKernelContext::SetOutput(int index, // The step's default allocator is the dummy XlaCompilationAllocator which // simply allocates a metadata buffer to hold the expression to which it // corresponds. 
- OP_REQUIRES_OK( - context_, - context_->allocate_output( - index, XLAShapeToTensorShape(*shape.ValueOrDie()), &output)); + TensorShape tensor_shape; + OP_REQUIRES_OK(context_, + XLAShapeToTensorShape(*shape.ValueOrDie(), &tensor_shape)); + OP_REQUIRES_OK(context_, + context_->allocate_output(index, tensor_shape, &output)); // The expression is stored in the tensor's data buffer. Fill in the // fields now. @@ -337,33 +331,34 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { expression->set_constant_value(constant); } -void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) { +void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Tensor* output = nullptr; - // The shape of the output tensor is the shape of the variable resource - // (i.e., a scalar), not the shape of the variable's value. + // The shape of the output tensor is the shape of the resource itself + // (i.e., a scalar), not the shape of the resource's value. 
OP_REQUIRES_OK(context_, context_->allocate_output(index, TensorShape(), &output)); XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_variable(variable); + expression->set_resource(resource); } -Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) { +Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { const XlaExpression* expression = CastExpressionFromTensor(context_->input(index)); - TF_RET_CHECK(expression->variable() != nullptr); - *variable = expression->variable(); + TF_RET_CHECK(expression->resource() != nullptr); + *resource = expression->resource(); return Status::OK(); } Status XlaOpKernelContext::AssignVariable( - int index, DataType type, const xla::ComputationDataHandle& handle) { + int input_index, DataType type, const xla::ComputationDataHandle& handle) { TF_RET_CHECK(handle.handle() != 0); SetOpHasSideEffects(); const XlaExpression* expression = - CastExpressionFromTensor(context_->input(index)); - XlaVariable* variable = expression->variable(); + CastExpressionFromTensor(context_->input(input_index)); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (!((variable->type == DT_INVALID && type != DT_INVALID) || (variable->type == type))) { return errors::InvalidArgument( @@ -398,11 +393,6 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( return XlaContext::Get(context_).GetOrCreateAdd(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateSigmoid( - const DataType type) { - return XlaContext::Get(context_).GetOrCreateSigmoid(type); -} - XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 
a25774c3a6a4a7212d157766a23e73063c2deab8..30b794c8c198cae6bf3b11794b35049b729063e1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -148,6 +148,12 @@ class XlaOpKernelContext { // Variables + // Sets '*resource' to the resource associated with input `index`. + Status GetResourceInput(int index, XlaResource** resource); + + // Sets output 'index' to be a reference to resource 'resource'. + void SetResourceOutput(int index, XlaResource* resource); + // Sets `*type` and `*shape` to the current type and shape of a variable's // value. Status GetVariableTypeAndShape(int index, DataType* type, @@ -158,20 +164,10 @@ class XlaOpKernelContext { Status ReadVariableInput(int index, xla::ComputationDataHandle* value); // Assigns the value `handle` to the variable referenced by input - // `variable_index`. Marks the operator as having side effects. - Status AssignVariable(int variable_index, DataType type, + // `input_index`. Marks the operator as having side effects. + Status AssignVariable(int input_index, DataType type, const xla::ComputationDataHandle& handle); - // Sets '*variable' to the variable associated with input `index`. - Status GetVariableInput(int index, XlaVariable** variable); - - // Sets output 'index' to be a reference to variable 'variable'. Used - // to propagate resource variables through the compilation. - void SetVariableOutput(int index, XlaVariable* variable); - - // Returns a human-readable debug string describing 'variable_index'. - string VariableDebugString(int variable_index); - // Helper routines for the OP_REQUIRES macros void CtxFailure(Status s); void CtxFailureWithWarning(Status s); @@ -205,11 +201,6 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); - // Get an XLA lambda to compute Sigmoid. This is cached in the - // XlaContext since it may be used by multiple Ops. 
There is a - // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateSigmoid(const DataType type); - private: OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 1bb0d8528994b957ccebeabce8bc48227122e366..d059c7a23ef2955cdd1280d1ceff7fc39b625631 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -24,6 +24,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -34,11 +36,18 @@ const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT"; const char* const DEVICE_XLA_CPU = "XLA_CPU"; const char* const DEVICE_XLA_GPU = "XLA_GPU"; -// Is platform 'id' supported by XLA? 
-static bool IsPlatformSupported(perftools::gputools::Platform::Id id) { - auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id); - if (!platform.ok()) return false; - return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok(); +static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + NodeDef node_def; + node_def.set_name("_XlaLaunch-op"); + node_def.set_op("_XlaLaunch"); + string kernel_class_name; + TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, + &kernel_class_name)); + VLOG(1) << "LaunchOpHasKernelForDevice" + << " kernel_class_name: " << kernel_class_name; + return Status::OK(); } XlaOpRegistry::XlaOpRegistry() = default; @@ -75,7 +84,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; // GetCompilationDevice is called. static void* registration_init = [&registry]() { mutex_lock lock(registry.mutex_); - if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) { + if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_CPU]; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; @@ -83,7 +92,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; registration.enable_jit_by_default = false; registration.compile_resource_ops = false; } - if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) { + if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_GPU]; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 2491cc3f7a2011827f4e093f287b525155153b71..c508071f8c1c436fee74bbac9ec76446aa76fd22 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -46,21 +46,18 @@ xla_proto_library( ], ) -# 
This is a headers target that extra XLA devices can use to prevent -# circular dependencies. Devices that are compiled as separate shared -# objects can also use it to prevent linking of library code. -cc_header_only_library( - name = "xla_headers_lib", - visibility = ["//visibility:public"], +cc_library( + name = "execution_options_util", + srcs = [ + "execution_options_util.cc", + ], + hdrs = [ + "execution_options_util.h", + ], + visibility = [":friends"], deps = [ - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_evaluator", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:stream_executor_headers_lib", + ":xla_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", ], ) @@ -602,3 +599,18 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. 
+cc_header_only_library( + name = "xla_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_data_proto", + ":xla_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:stream_executor_headers_lib", + ], +) diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index d93f968f4d7a8c30129f4e14c4db06c25187cb45..4c7fce1aaf1faf4bd08bca38bc8eb2b47303b575 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -207,6 +207,18 @@ class Array4D { } } + // Invokes a callback with the (indices, value) for each cell in the 4D array. + void Each( + std::function<void(tensorflow::gtl::ArraySlice<int64>, T)> f) const { + // We const_cast to be able to use the common non-const implementation, + // but prevent modification of the data by passing it by-value to the + // caller. + const_cast<Array4D*>(this)->Each( + [&f](tensorflow::gtl::ArraySlice<int64> indices, T* value) { + f(indices, *value); + }); + } + // Fills all of the {p,z} with the array provided, which specifies {y,x}. 
void FillWithYX(const Array2D& value) { CHECK_EQ(value.height(), height()); diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 63c6d9ddaca5e9e336e29cd3b23cfd921d4ce9e7..a998b91c89d79ac5e354d2a3edf5fb78695d73cb 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -62,6 +62,7 @@ cc_library( deps = [ ":computation", ":global_data", + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:service_interface", "//tensorflow/compiler/xla:status_macros", @@ -70,6 +71,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], @@ -114,7 +116,6 @@ cc_library( "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 454d0fbd9650c4d77a62b4c25a5407e36bd191f8..1799bbd3480daacc204b42f168a7f8e9149db58b 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -197,7 +199,10 @@ StatusOr> Client::Execute( ExecutionProfile* execution_profile) { ExecuteRequest request; *request.mutable_computation() = computation.handle(); - if (execution_options != nullptr) { + + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { *request.mutable_execution_options() = *execution_options; } for (GlobalData* argument : arguments) { @@ -298,7 +303,9 @@ StatusOr Client::ExecuteAsync( for (GlobalData* argument : arguments) { *request.add_arguments() = argument->handle(); } - if (execution_options != nullptr) { + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { *request.mutable_execution_options() = *execution_options; } @@ -376,9 +383,10 @@ StatusOr>> Client::DeconstructTuple( } StatusOr Client::GetComputationStats( - const Computation& computation) const { + const Computation& computation, const DebugOptions& debug_options) const { ComputationStatsRequest request; *request.mutable_computation() = computation.handle(); + *request.mutable_debug_options() = debug_options; ComputationStatsResponse response; VLOG(1) << "making computation stats request"; @@ -427,7 +435,10 @@ StatusOr Client::GetShape(const GlobalData& data) { StatusOr Client::ExecutionStatsAsString( const Computation& computation, const ExecutionProfile& profile) { - TF_ASSIGN_OR_RETURN(auto computation_stats, GetComputationStats(computation)); + TF_ASSIGN_OR_RETURN( + auto computation_stats, + GetComputationStats(computation, + legacy_flags::GetDebugOptionsFromFlags())); int64 total_flops = computation_stats.flop_count() + 
computation_stats.transcendental_count(); if (profile.compute_time_ns() > 0) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 797835160fa2850f108e85ff3147abffd9f86ad8..69d3642911fa8fe87ceb347d929e95ffd972615b 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -150,7 +150,7 @@ class Client { // Retrieves the statistics of the given computation. StatusOr GetComputationStats( - const Computation& computation) const; + const Computation& computation, const DebugOptions& debug_options) const; // Returns the Shape of the given array specified by 'data'. The shape // includes the Layout of the array as it is stored on the service. diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 37bf697683b0f5f61a1b915628920b0752116a32..dcc313707b93248842f0c1500afbd449e2048549 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -971,6 +971,11 @@ ComputationDataHandle ComputationBuilder::Sign( return UnaryOp(UNOP_SIGN, operand); } +ComputationDataHandle ComputationBuilder::Cos( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_COS, operand); +} + ComputationDataHandle ComputationBuilder::Tanh( const ComputationDataHandle& operand) { return UnaryOp(UNOP_TANH, operand); @@ -1411,6 +1416,52 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::BatchNormTraining( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& offset, float epsilon, int64 feature_index) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + BatchNormTrainingRequest request; + *request.mutable_operand() = operand; + *request.mutable_scale() = scale; 
+ *request.mutable_offset() = offset; + request.set_epsilon(epsilon); + request.set_feature_index(feature_index); + + OpRequest op_request; + *op_request.mutable_batch_norm_training_request() = request; + *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); + + OpResponse response; + + VLOG(2) << "making BatchNormTraining request"; + + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::BatchNormInference( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& offset, const ComputationDataHandle& mean, + const ComputationDataHandle& variance, float epsilon, int64 feature_index) { + // TODO(b/62843645): Implement BatchNormInference. + NoteError(Unimplemented("BatchNormInference is not implemented yet.")); + return ComputationDataHandle(); +} + +ComputationDataHandle ComputationBuilder::BatchNormGrad( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& batch_mean, + const ComputationDataHandle& batch_var, + const ComputationDataHandle& grad_output, float epsilon, + int64 feature_index) { + // TODO(b/62843645): Implement BatchNormGrad. 
+ NoteError(Unimplemented("BatchNormGrad is not implemented yet.")); + return ComputationDataHandle(); +} + ComputationDataHandle ComputationBuilder::CrossReplicaSum( const ComputationDataHandle& operand) { if (!first_error_.ok() || !PrepareComputation().ok()) { @@ -1487,6 +1538,28 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::ReducePrecision( + const ComputationDataHandle& operand, const int exponent_bits, + const int mantissa_bits) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ReducePrecisionRequest request; + *request.mutable_operand() = operand; + request.set_exponent_bits(exponent_bits); + request.set_mantissa_bits(mantissa_bits); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_reduce_precision_request() = request; + AddOpMetadata(&op_request); + OpResponse response; + + VLOG(2) << "making reduce-precision request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + void ComputationBuilder::Send(const ComputationDataHandle& operand, const ChannelHandle& handle) { if (!first_error_.ok() || !PrepareComputation().ok()) { diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 5cc73c28d03a097a4fd5b8d3a549ffdc43c6fcd3..b411346459e87df3e8f3eed679d6e16b6e5ee894 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -510,6 +510,9 @@ class ComputationBuilder { // Enqueues a sign instruction onto the computation. ComputationDataHandle Sign(const ComputationDataHandle& operand); + // Enqueues a cosine instruction onto the computation. 
+ ComputationDataHandle Cos(const ComputationDataHandle& operand); + // Enqueues a tanh instruction onto the computation. ComputationDataHandle Tanh(const ComputationDataHandle& operand); @@ -597,6 +600,11 @@ class ComputationBuilder { const Computation& body, const ComputationDataHandle& init); + // Enqueues a ReducePrecision node onto the computation. + ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, + const int exponent_bits, + const int mantissa_bits); + // Enqueues a Send node onto the computation, to send the given operand to // a Recv instruction that shares the same channel handle. void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); @@ -820,87 +828,80 @@ class ComputationBuilder { template ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantOp( - [value](Literal* literal) { LiteralUtil::PopulateR0(value, literal); }); + return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); }); } template ComputationDataHandle ComputationBuilder::ConstantR1( tensorflow::gtl::ArraySlice values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR1(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR1(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, NativeT value) { return ConstantOp([length, value](Literal* literal) { - LiteralUtil::PopulateWithValue(value, {length}, literal); + literal->PopulateWithValue(value, {length}); }); } inline ComputationDataHandle ComputationBuilder::ConstantR1( const tensorflow::core::Bitmap& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR1(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR1(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR2( std::initializer_list> values) { - return ConstantOp([&values](Literal* literal) { - 
LiteralUtil::PopulateR2(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR2(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR2FromArray2DWithLayout(values, layout, literal); + literal->PopulateR2FromArray2DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR2FromArray2D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR3FromArray3DWithLayout(values, layout, literal); + literal->PopulateR3FromArray3DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( const Array3D& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR3FromArray3D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal); + literal->PopulateR4FromArray4DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( const Array4D& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR4FromArray4D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { 
literal->PopulateR4FromArray4D(values); }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 86b16be62f041ae3e96591627501592b34203e16..edd971e114c0769f092a70e0f06d5a4db7134dda 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -32,6 +32,7 @@ cc_library( srcs = ["testing.cc"], hdrs = ["testing.h"], deps = [ + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index daa1557df0b97ee20679f45b8d54164ca93555fa..d8bfc945807d061234c1bc5999ea377a72e85a62 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,11 +35,11 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, client, tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); // TODO(b/26811613): Replace this when RNG is supported on all backends. 
- b.Broadcast(b.ConstantLiteral(LiteralUtil::One(shape.element_type())), + b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); Computation computation = b.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape; return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 67f3a6c1df4d74e5ef714dcaa56bae1e81f8276a..33d5b6f1d4d15d5143a3421c87eab9b7a7d11345 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const { return execution_profile_; } +ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( + DeviceAssignment* device_assignment) { + device_assignment_ = device_assignment; + return *this; +} + +DeviceAssignment* ExecutableRunOptions::device_assignment() const { + return device_assignment_; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 03f2d016ad07b63e6b7d9681c86885ce947f5319..deb3ddb203d263d25bef0499a8a53a6098d0de0c 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -40,6 +40,7 @@ struct ThreadPoolDevice; namespace xla { class DeviceMemoryAllocator; +class DeviceAssignment; class ExecutionProfile; // Class containing options for running a LocalExecutable. 
@@ -79,9 +80,14 @@ class ExecutableRunOptions { ExecutionProfile* execution_profile() const; ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); + ExecutableRunOptions& set_device_assignment( + DeviceAssignment* device_assignment); + DeviceAssignment* device_assignment() const; + private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; + DeviceAssignment* device_assignment_ = nullptr; perftools::gputools::Stream* stream_ = nullptr; tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; diff --git a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html b/tensorflow/compiler/xla/execution_options_util.cc similarity index 50% rename from tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html rename to tensorflow/compiler/xla/execution_options_util.cc index a325f0a04cd033dd89b870a2fc6eca9a7a6f0020..e83ff7cddd675197c7f6d7018257edb4c25b6228 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html +++ b/tensorflow/compiler/xla/execution_options_util.cc @@ -1,6 +1,4 @@ - +==============================================================================*/ +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" - - +namespace xla { - - - - - +} // namespace xla diff --git a/tensorflow/compiler/xla/execution_options_util.h b/tensorflow/compiler/xla/execution_options_util.h new file mode 100644 index 0000000000000000000000000000000000000000..562da78e837ea6c4a01f0d1170797340fd421ad8 --- /dev/null +++ b/tensorflow/compiler/xla/execution_options_util.h @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ + +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla { + +// Create a default ExecutionOptions proto; this proto has its debug options +// populated to the default values taken from flags. +ExecutionOptions CreateDefaultExecutionOptions(); + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index a147ce67a28884d485280b4d811875d569fad879..1fdb6d59cfc1b9da13e1b0f43b672ca0af8e24c8 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -73,85 +73,12 @@ cc_library( deps = [ ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], ) -cc_library( - name = "cpu_compiler_flags", - srcs = ["cpu_compiler_flags.cc"], - hdrs = ["cpu_compiler_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "cpu_runtime_flags", - srcs = ["cpu_runtime_flags.cc"], - hdrs = ["cpu_runtime_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - 
name = "compiler_functor_flags", - srcs = ["compiler_functor_flags.cc"], - hdrs = ["compiler_functor_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "convolution_thunk_flags", - srcs = ["convolution_thunk_flags.cc"], - hdrs = ["convolution_thunk_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "gpu_compiler_flags", - srcs = ["gpu_compiler_flags.cc"], - hdrs = ["gpu_compiler_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "gpu_backend_lib_flags", - srcs = ["gpu_backend_lib_flags.cc"], - hdrs = ["gpu_backend_lib_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "stream_assignment_flags", srcs = ["stream_assignment_flags.cc"], @@ -163,40 +90,6 @@ cc_library( ], ) -cc_library( - name = "hlo_graph_dumper_flags", - srcs = ["hlo_graph_dumper_flags.cc"], - hdrs = ["hlo_graph_dumper_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "alias_analysis_flags", - srcs = ["alias_analysis_flags.cc"], - hdrs = ["alias_analysis_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "llvm_util_flags", - srcs = ["llvm_util_flags.cc"], - hdrs = ["llvm_util_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "service_flags", srcs = ["service_flags.cc"], @@ -209,28 
+102,6 @@ cc_library( ], ) -cc_library( - name = "buffer_assignment_flags", - srcs = ["buffer_assignment_flags.cc"], - hdrs = ["buffer_assignment_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "hlo_test_base_flags", - srcs = ["hlo_test_base_flags.cc"], - hdrs = ["hlo_test_base_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "backend_flags", srcs = ["backend_flags.cc"], @@ -243,18 +114,6 @@ cc_library( ], ) -cc_library( - name = "user_computation_flags", - srcs = ["user_computation_flags.cc"], - hdrs = ["user_computation_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc deleted file mode 100644 index 474753c10ad7ed5eb4a9a446c3f877280c5ad302..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// Legacy flags for XLA's alias_analysis module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static AliasAnalysisFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new AliasAnalysisFlags; - flags->xla_emit_alias_scope = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_emit_alias_scope", &flags->xla_emit_alias_scope, - "Use buffer analysis to refine alias-analysis."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's alias_analysis -// module. -void AppendAliasAnalysisFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the AliasAnalysisFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-AliasAnalysisFlags* GetAliasAnalysisFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h deleted file mode 100644 index 369f8cd7caa6f42273cd405ca5f43d325e457128..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ - -// Legacy flags for XLA's alias_analysis module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's alias_analysis -// module. -void AppendAliasAnalysisFlags(std::vector* flag_list); - -// The values of flags associated with XLA's alias_analysis module. -typedef struct { - bool xla_emit_alias_scope; // Use buffer analysis to refine alias-analysis. -} AliasAnalysisFlags; - -// Return a pointer to the AliasAnalysisFlags struct; -// repeated calls return the same pointer. 
-// This should be called only after Flags::Parse() has returned. -AliasAnalysisFlags* GetAliasAnalysisFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc deleted file mode 100644 index 71873f73afd5bb8c59832a4c82f87f4e51c31180..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's buffer_assignment module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static BufferAssignmentFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new BufferAssignmentFlags; - flags->xla_enable_buffer_reuse = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_enable_buffer_reuse", - &flags->xla_enable_buffer_reuse, - "Enable reuse of buffers."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's buffer_assignment -// module. -void AppendBufferAssignmentFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the BufferAssignmentFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BufferAssignmentFlags* GetBufferAssignmentFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h deleted file mode 100644 index 5f098c2663f638940aead45b74332edcf3fcc37f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ - -// Legacy flags for XLA's buffer_assignment module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's buffer_assignment -// module. -void AppendBufferAssignmentFlags(std::vector* flag_list); - -// The values of flags associated with XLA's buffer_assignment module. -typedef struct { - bool xla_enable_buffer_reuse; // Enable reuse of buffers. -} BufferAssignmentFlags; - -// Return a pointer to the BufferAssignmentFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -BufferAssignmentFlags* GetBufferAssignmentFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc deleted file mode 100644 index 617a9b712ed99d343dc28b6e6c0de4b54e271096..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's compiler_functor module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static CompilerFunctorFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new CompilerFunctorFlags; - flag_list = new std::vector({ - tensorflow::Flag("xla_debug_cpu_dump_ir", &flags->xla_debug_cpu_dump_ir, - "Dump IR, before optimizations to a path"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's compiler_functor -// module. -void AppendCompilerFunctorFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the CompilerFunctorFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-CompilerFunctorFlags* GetCompilerFunctorFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h deleted file mode 100644 index 28b505ec5eac2d74879a22779137c6982a7c9ce8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ - -// Legacy flags for the XLA's compiler_functor module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's compiler_functor -// module. -void AppendCompilerFunctorFlags(std::vector* flag_list); - -// The values of flags associated with XLA's compiler_functor module. 
-typedef struct { - string xla_debug_cpu_dump_ir; // Dump IR, before optimizations to a path -} CompilerFunctorFlags; - -// Return a pointer to the CompilerFunctorFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CompilerFunctorFlags* GetCompilerFunctorFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc deleted file mode 100644 index fe5d19147f09557817fee5c670f52058f21f5cdc..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's convolution_thunk module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include - -#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static ConvolutionThunkFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new ConvolutionThunkFlags; - flags->xla_gpu_autotune_convolution_algorithm = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_gpu_autotune_convolution_algorithm", - &flags->xla_gpu_autotune_convolution_algorithm, - "Auto-tune the algorithm used by convolution"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's convolution_thunk -// module. -void AppendConvolutionThunkFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the ConvolutionThunkFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ConvolutionThunkFlags* GetConvolutionThunkFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h deleted file mode 100644 index 53d6806a71902d1227728f74bd45f12f9d11421d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ - -// Legacy flags for XLA's convolution_thunk module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's convolution_thunk -// module. -void AppendConvolutionThunkFlags(std::vector* flag_list); - -// The values of flags associated with XLA's convolution_thunk module. -typedef struct { - // Auto-tune the algorithm used by convolution - bool xla_gpu_autotune_convolution_algorithm; -} ConvolutionThunkFlags; - -// Return a pointer to the ConvolutionThunkFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-ConvolutionThunkFlags* GetConvolutionThunkFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc deleted file mode 100644 index 13d41a8636b6ba3aa88545523e93dffe4b0c12f5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's cpu_compiler module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static CpuCompilerFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new CpuCompilerFlags; - flags->xla_cpu_embed_ir = false; - flags->xla_cpu_dump_debug_json_to = ""; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_cpu_embed_ir", &flags->xla_cpu_embed_ir, - "Embed the LLVM IR module string in the resultant CpuExecutable."), - tensorflow::Flag("xla_cpu_dump_debug_json_to", - &flags->xla_cpu_dump_debug_json_to, - "Dump debug JSON to this directory."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's cpu_compiler -// module. -void AppendCpuCompilerFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the CpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CpuCompilerFlags* GetCpuCompilerFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h deleted file mode 100644 index bac498e18eb241d3b3044f14c88ac2b3aaaa322f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ - -// Legacy flags for the XLA's cpu_compiler module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's cpu_compiler -// module. -void AppendCpuCompilerFlags(std::vector* flag_list); - -// The values of flags associated with XLA's cpu_compiler module. -typedef struct { - bool xla_cpu_embed_ir; // Embed the LLVM IR module string in the resultant - // CpuExecutable - string xla_cpu_dump_debug_json_to; // Dump debug JSON to this directory. -} CpuCompilerFlags; - -// Return a pointer to the CpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CpuCompilerFlags* GetCpuCompilerFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc deleted file mode 100644 index d7817c5d54a047b1987a19dfbde9f48081ae6413..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's cpu_runtime module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static CpuRuntimeFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new CpuRuntimeFlags; - flags->xla_cpu_use_eigen = true; - flags->xla_cpu_multi_thread_eigen = true; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_cpu_use_eigen", &flags->xla_cpu_use_eigen, - "Use Eigen for matrix multiply on the CPU platform. This " - "is a useful hack for performance comparisons against " - "XLA's implementation."), - tensorflow::Flag( - "xla_cpu_multi_thread_eigen", &flags->xla_cpu_multi_thread_eigen, - "When generating calls to Eigen for matmul and conv, should " - "single or multi-threaded eigen be used? " - "Only used when --xla_cpu_use_eigen is true."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's cpu_runtime -// module. 
-void AppendCpuRuntimeFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the CpuRuntimeFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CpuRuntimeFlags* GetCpuRuntimeFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h deleted file mode 100644 index e3ff30da36a5fabd7d7798fd636cb3955a91b09f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ - -// Legacy flags for the XLA's cpu_runtime module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's cpu_runtime -// module. 
-void AppendCpuRuntimeFlags(std::vector* flag_list); - -// The values of flags associated with XLA's cpu_runtime module. -typedef struct { - // Use Eigen for matrix multiply on the CPU platform. This is a useful hack - // for performance comparisons against XLA's implementation. - bool xla_cpu_use_eigen; - // When generating calls to Eigen for matmul and conv, should single or - // multi-threaded eigen be used? Only used when --xla_cpu_use_eigen is true. - bool xla_cpu_multi_thread_eigen; -} CpuRuntimeFlags; - -// Return a pointer to the CpuRuntimeFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CpuRuntimeFlags* GetCpuRuntimeFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 5e3c4f912bf6073e89a66633c44a7e052ca43ade..2f3ec403a058c47daa53337cad65c5e90dbc8748 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -25,9 +25,29 @@ namespace legacy_flags { struct DebugOptionsFlags { string xla_generate_hlo_graph; + bool xla_hlo_graph_addresses; + bool xla_hlo_graph_layout; + string xla_hlo_graph_path; + bool xla_hlo_dump_as_graphdef; + string xla_log_hlo_text; + string xla_generate_hlo_text_to; + string xla_disable_hlo_passes; bool xla_enable_fast_math; + bool xla_llvm_enable_alias_scope_metadata; + bool xla_llvm_enable_noalias_metadata; + bool xla_llvm_enable_invariant_load_metadata; int32 xla_backend_optimization_level; + bool xla_embed_ir_in_executable; + string xla_dump_ir_to; + string xla_dump_debug_json_to; + bool xla_eliminate_hlo_implicit_broadcast; + + bool xla_cpu_multi_thread_eigen; + + string xla_gpu_cuda_data_dir; + bool xla_gpu_ftz; + string xla_backend_extra_options; }; @@ -42,9 
+62,25 @@ std::once_flag flags_init; void AllocateFlags() { flag_values = new DebugOptionsFlags; flag_values->xla_generate_hlo_graph = ""; + flag_values->xla_hlo_graph_addresses = false; + flag_values->xla_hlo_graph_layout = false; + flag_values->xla_hlo_graph_path = "/tmp/"; + flag_values->xla_hlo_dump_as_graphdef = false; + flag_values->xla_log_hlo_text = ""; + flag_values->xla_generate_hlo_text_to = ""; flag_values->xla_disable_hlo_passes = ""; flag_values->xla_enable_fast_math = true; - flag_values->xla_backend_optimization_level = 2; + flag_values->xla_llvm_enable_alias_scope_metadata = true; + flag_values->xla_llvm_enable_noalias_metadata = true; + flag_values->xla_llvm_enable_invariant_load_metadata = true; + flag_values->xla_backend_optimization_level = 3; + flag_values->xla_embed_ir_in_executable = false; + flag_values->xla_dump_ir_to = ""; + flag_values->xla_dump_debug_json_to = ""; + flag_values->xla_eliminate_hlo_implicit_broadcast = false; + flag_values->xla_cpu_multi_thread_eigen = true; + flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib"; + flag_values->xla_gpu_ftz = false; flag_values->xla_backend_extra_options = ""; flag_objects = new std::vector( @@ -52,27 +88,83 @@ void AllocateFlags() { "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph, "HLO modules matching this regex will be dumped to a .dot file " "throughout various stages in compilation."), - + tensorflow::Flag( + "xla_hlo_graph_addresses", &flag_values->xla_hlo_graph_addresses, + "With xla_generate_hlo_graph, show addresses of HLO ops in " + "graph dump."), + tensorflow::Flag( + "xla_hlo_graph_layout", &flag_values->xla_hlo_graph_layout, + "With xla_generate_hlo_graph, show layout of HLO ops in " + "graph dump."), + tensorflow::Flag( + "xla_hlo_graph_path", &flag_values->xla_hlo_graph_path, + "With xla_generate_hlo_graph, dump the graphs into this path."), + tensorflow::Flag("xla_hlo_dump_as_graphdef", + &flag_values->xla_hlo_dump_as_graphdef, + "Dump HLO graphs as 
TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_log_hlo_text", &flag_values->xla_log_hlo_text, + "HLO modules matching this regex will be dumped to LOG(INFO). "), + tensorflow::Flag( + "xla_generate_hlo_text_to", &flag_values->xla_generate_hlo_text_to, + "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( "xla_enable_fast_math", &flag_values->xla_enable_fast_math, "Enable unsafe fast-math optimizations in the compiler; " "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag("xla_llvm_enable_alias_scope_metadata", + &flag_values->xla_llvm_enable_alias_scope_metadata, + "In LLVM-based backends, enable the emission of " + "!alias.scope metadata in the generated IR."), + tensorflow::Flag("xla_llvm_enable_noalias_metadata", + &flag_values->xla_llvm_enable_noalias_metadata, + "In LLVM-based backends, enable the emission of " + "!noalias metadata in the generated IR."), + tensorflow::Flag("xla_llvm_enable_invariant_load_metadata", + &flag_values->xla_llvm_enable_invariant_load_metadata, + "In LLVM-based backends, enable the emission of " + "!invariant.load metadata in " + "the generated IR."), tensorflow::Flag( "xla_backend_optimization_level", &flag_values->xla_backend_optimization_level, "Numerical optimization level for the XLA compiler backend."), - + tensorflow::Flag( + "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes, + "Comma-separated list of hlo passes to be disabled. 
These names " + "must exactly match the passes' names; no whitespace around " + "commas."), + tensorflow::Flag("xla_embed_ir_in_executable", + &flag_values->xla_embed_ir_in_executable, + "Embed the compiler IR as a string in the executable."), + tensorflow::Flag("xla_dump_ir_to", &flag_values->xla_dump_ir_to, + "Dump the compiler IR into this file/path."), + tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast", + &flag_values->xla_eliminate_hlo_implicit_broadcast, + "Eliminate implicit broadcasts when lowering user " + "computations to HLO instructions; use explicit " + "broadcast instead."), + tensorflow::Flag("xla_cpu_multi_thread_eigen", + &flag_values->xla_cpu_multi_thread_eigen, + "When generating calls to Eigen in the CPU backend, " + "use multi-threaded Eigen mode."), + tensorflow::Flag("xla_gpu_cuda_data_dir", + &flag_values->xla_gpu_cuda_data_dir, + "If non-empty, specifies a local directory containing " + "ptxas and nvvm libdevice files; otherwise we use " + "those from runfile directories."), + tensorflow::Flag("xla_gpu_ftz", &flag_values->xla_gpu_ftz, + "If true, flush-to-zero semantics are enabled in the " + "code generated for GPUs."), + tensorflow::Flag( + "xla_dump_debug_json_to", &flag_values->xla_dump_debug_json_to, + "Dump compilation artifacts as JSON into this directory."), tensorflow::Flag("xla_backend_extra_options", &flag_values->xla_backend_extra_options, "Extra options to pass to a backend; " "comma-separated list of 'key=val' strings (=val " - "may be omitted); no whitespace around commas."), + "may be omitted); no whitespace around commas.")}); - tensorflow::Flag( - "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes, - "Comma-separated list of HLO passes to be disabled.
These names " - "must exactly match the passes' names; " - "no whitespace around commas.")}); ParseFlagsFromEnv(*flag_objects); } @@ -89,6 +181,12 @@ xla::DebugOptions GetDebugOptionsFromFlags() { DebugOptions options; options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph); + options.set_xla_hlo_graph_addresses(flag_values->xla_hlo_graph_addresses); + options.set_xla_hlo_graph_layout(flag_values->xla_hlo_graph_layout); + options.set_xla_hlo_graph_path(flag_values->xla_hlo_graph_path); + options.set_xla_hlo_dump_as_graphdef(flag_values->xla_hlo_dump_as_graphdef); + options.set_xla_log_hlo_text(flag_values->xla_log_hlo_text); + options.set_xla_generate_hlo_text_to(flag_values->xla_generate_hlo_text_to); std::vector disabled_passes = tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ','); @@ -99,6 +197,23 @@ xla::DebugOptions GetDebugOptionsFromFlags() { options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math); options.set_xla_backend_optimization_level( flag_values->xla_backend_optimization_level); + options.set_xla_embed_ir_in_executable( + flag_values->xla_embed_ir_in_executable); + options.set_xla_dump_ir_to(flag_values->xla_dump_ir_to); + options.set_xla_eliminate_hlo_implicit_broadcast( + flag_values->xla_eliminate_hlo_implicit_broadcast); + options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to); + + options.set_xla_cpu_multi_thread_eigen( + flag_values->xla_cpu_multi_thread_eigen); + options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir); + options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz); + options.set_xla_llvm_enable_alias_scope_metadata( + flag_values->xla_llvm_enable_alias_scope_metadata); + options.set_xla_llvm_enable_noalias_metadata( + flag_values->xla_llvm_enable_noalias_metadata); + options.set_xla_llvm_enable_invariant_load_metadata( + flag_values->xla_llvm_enable_invariant_load_metadata); std::vector extra_options_parts = 
tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ','); diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc deleted file mode 100644 index f8f6ea26b1d0df67b934616fe60aa29199fc2eb9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's gpu_backend_lib module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static GpuBackendLibFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new GpuBackendLibFlags; - flags->dump_temp_products_to = ""; - flags->ftz = false; - flags->fma = true; - flags->verbose_ptx_asm = false; - flags->kernel = ""; - flags->llvm_dump_passes = false; - flags->llvm_cl_opts = ""; - flags->dump_ir_before_passes = false; - flags->opt_level = 3; - flag_list = new std::vector({ - tensorflow::Flag("dump_temp_products_to", &flags->dump_temp_products_to, - "dump temporary compilation products to this directory. " - "If empty, no dump is produced"), - tensorflow::Flag("ftz", &flags->ftz, "flush to zero semantics"), - tensorflow::Flag("fma", &flags->fma, "use FMA synthesis"), - tensorflow::Flag("verbose_ptx_asm", &flags->verbose_ptx_asm, - "emit PTX assembly with extra comments"), - tensorflow::Flag("kernel", &flags->kernel, - "only emit the IR and PTX for this kernel"), - tensorflow::Flag("llvm_dump_passes", &flags->llvm_dump_passes, - "dump the passes LLVM runs to stderr"), - tensorflow::Flag( - "llvm_cl_opts", &flags->llvm_cl_opts, - "comma-separated list of command line options to pass to " - "LLVM. For example, --llvm_cl_opts=--print-before=loop-unroll"), - tensorflow::Flag("dump_ir_before_passes", &flags->dump_ir_before_passes, - "dump the IR before each optimization pass in " - "sequentially-named files."), - tensorflow::Flag("opt_level", &flags->opt_level, - "optimization level (default to 3)"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's gpu_backend_lib -// module. -void AppendGpuBackendLibFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the GpuBackendLibFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-GpuBackendLibFlags* GetGpuBackendLibFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h deleted file mode 100644 index 31cb50e9da986b5bad3e71439a4976ec84e17be7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ - -// Legacy flags for XLA's gpu_backend_lib module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's gpu_backend_lib -// module. -void AppendGpuBackendLibFlags(std::vector* flag_list); - -// The values of flags associated with XLA's gpu_backend_lib module. 
-typedef struct { - string dump_temp_products_to; // temporary compilation products dir - bool ftz; // flush to zero semantics - bool fma; // use FMA synthesis - bool verbose_ptx_asm; // emit PTX assembly with extra comments - string kernel; // only emit the IR and PTX for this kernel - bool llvm_dump_passes; // dump the passes LLVM runs to stderr - string llvm_cl_opts; // comma-separated list of LLVM options - bool dump_ir_before_passes; // dump IR before each pass - int32 opt_level; // optimization level -} GpuBackendLibFlags; - -// Return a pointer to the GpuBackendLibFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -GpuBackendLibFlags* GetGpuBackendLibFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc deleted file mode 100644 index 131e3ce70ac9e7fc2f6f233ffd93e8757d0bc725..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's gpu_compiler module. 
- -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static GpuCompilerFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new GpuCompilerFlags; - flags->xla_gpu_embed_ir = false; - flags->xla_cuda_data_dir = "./cuda_sdk_lib"; - flags->xla_gpu_dump_debug_json_to = ""; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir, - "Embed the LLVM IR module string in the resultant GpuExecutable."), - tensorflow::Flag( - "xla_cuda_data_dir", &flags->xla_cuda_data_dir, - "If non-empty, specifies a local directory containing ptxas and " - "nvvm libdevice files. Otherwise, by default, we use those from " - "runfile directories."), - tensorflow::Flag("xla_ptxas_path", &flags->xla_ptxas_path, - "The path to ptxas. Required to log stats of the ptx."), - tensorflow::Flag("xla_gpu_dump_debug_json_to", - &flags->xla_gpu_dump_debug_json_to, - "Dump debug JSON to this directory."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's gpu_compiler -// module. -void AppendGpuCompilerFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the GpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-GpuCompilerFlags* GetGpuCompilerFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h deleted file mode 100644 index 0cf39e0ab35e663c7abc14980daa8b92d15489d6..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ - -// Legacy flags for XLA's gpu_compiler module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's gpu_compiler -// module. -void AppendGpuCompilerFlags(std::vector* flag_list); - -// The values of flags associated with XLA's gpu_compiler module. -typedef struct { - bool xla_gpu_embed_ir; // Embed the LLVM IR module string in the resultant - // GpuExecutable. 
- string xla_cuda_data_dir; // If non-empty, specifies a local directory - // containing ptxas and nvvm libdevice files. - // Otherwise, by default, we use those from runfile - // directories. - string xla_ptxas_path; // The path to ptxas. Required to log stats of - // the ptx. - string xla_gpu_dump_debug_json_to; // Dump debug JSON to this directory. -} GpuCompilerFlags; - -// Return a pointer to the GpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -GpuCompilerFlags* GetGpuCompilerFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc deleted file mode 100644 index ba43a5919522ff783f450481c629d64613e1f8ab..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's hlo_graph_dumper module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include - -#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static HloGraphDumperFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new HloGraphDumperFlags; - flags->xla_hlo_dump_graph_path = "/tmp/"; - flags->xla_hlo_dump_as_graphdef = false; - flag_list = new std::vector({ - tensorflow::Flag("xla_hlo_dump_graph_path", - &flags->xla_hlo_dump_graph_path, - "Path to write dumped HLO graphs to"), - tensorflow::Flag("xla_hlo_dump_as_graphdef", - &flags->xla_hlo_dump_as_graphdef, - "Dumps HLO graphs as tensorflow GraphDefs"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_graph_dumper -// module. -void AppendHloGraphDumperFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the HloGraphDumperFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-HloGraphDumperFlags* GetHloGraphDumperFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h deleted file mode 100644 index d0b4d092ff1003bc1df90c3d878feacf71a5aa21..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ - -// Legacy flags for XLA's hlo_graph_dumper module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's hlo_graph_dumper -// module. -void AppendHloGraphDumperFlags(std::vector* flag_list); - -// The values of flags associated with XLA's hlo_graph_dumper module. 
-typedef struct { - string xla_hlo_dump_graph_path; // Path to write dumped HLO graphs to - // If set, dumps HLO graphs as tensorflow GraphDef; otherwise, dumps HLO - // graphs as DOT graph. - bool xla_hlo_dump_as_graphdef; -} HloGraphDumperFlags; - -// Return a pointer to the HloGraphDumperFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -HloGraphDumperFlags* GetHloGraphDumperFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc deleted file mode 100644 index c7893c138596b034dbb83df9fda2d4c5edd8e32b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's hlo_test_base module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include - -#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static HloTestBaseFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new HloTestBaseFlags; - flags->xla_hlo_test_generate_hlo_graph = false; - flag_list = new std::vector({ - tensorflow::Flag("xla_hlo_test_generate_hlo_graph", - &flags->xla_hlo_test_generate_hlo_graph, - "Generate graph output of HLO instructions"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_test_base -// module. -void AppendHloTestBaseFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the HloTestBaseFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -HloTestBaseFlags* GetHloTestBaseFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h deleted file mode 100644 index 23b808cecb7e5eaf480292f5207a4b87ebd4a2d5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ - -// Legacy flags for XLA's hlo_test_base module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's hlo_test_base -// module. -void AppendHloTestBaseFlags(std::vector* flag_list); - -// The values of flags associated with XLA's hlo_test_base module. -typedef struct { - bool xla_hlo_test_generate_hlo_graph; // Generate graph output of HLO - // instructions -} HloTestBaseFlags; - -// Return a pointer to the HloTestBaseFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-HloTestBaseFlags* GetHloTestBaseFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc deleted file mode 100644 index 3c53729a67049fdac6b358149e06f39858ebd98f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's llvm_util module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static LlvmUtilFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new LlvmUtilFlags; - flags->xla_emit_tbaa = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_emit_tbaa", &flags->xla_emit_tbaa, - "Perform type-based alias analysis optimizations for " - "LLVM-based backends."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's llvm_util -// module. -void AppendLlvmUtilFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the LlvmUtilFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -LlvmUtilFlags* GetLlvmUtilFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h deleted file mode 100644 index 98da26b4b806dd83c7baf6bdcf60cbf5297457a6..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ - -// Legacy flags for XLA's llvm_util module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's llvm_util module. -void AppendLlvmUtilFlags(std::vector* flag_list); - -// The values of flags associated with XLA's llvm_util module. -typedef struct { - bool xla_emit_tbaa; // Perform type-based alias analysis optimizations for - // LLVM-based backends. -} LlvmUtilFlags; - -// Return a pointer to the LlvmUtilFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -LlvmUtilFlags* GetLlvmUtilFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.cc b/tensorflow/compiler/xla/legacy_flags/service_flags.cc index 41cb8d8bdfc51de1d8fe77906317b4b4a0804802..90d30e756905259ffcd4ac10163b6cd75c51ae0a 100644 --- a/tensorflow/compiler/xla/legacy_flags/service_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/service_flags.cc @@ -36,33 +36,13 @@ static std::once_flag flags_init; static void AllocateFlags() { flags = new ServiceFlags; flags->xla_hlo_profile = false; - flags->xla_log_hlo_text = ""; - flags->xla_generate_hlo_graph = ""; - flags->xla_hlo_graph_addresses = false; - flags->xla_hlo_graph_layout = false; flags->xla_hlo_graph_for_compute_constant = false; flags->xla_dump_computations_to = ""; - flags->xla_dump_hlo_text_to = ""; flags->xla_dump_executions_to = ""; flag_list = new std::vector({ tensorflow::Flag( "xla_hlo_profile", &flags->xla_hlo_profile, "Instrument the computation to 
collect per-HLO cycle counts"), - tensorflow::Flag( - "xla_log_hlo_text", &flags->xla_log_hlo_text, - "If non-empty, print the text format of " - "HLO modules whose name partially matches this regex. E.g. " - "xla_log_hlo_text=.* will dump the text for every module."), - tensorflow::Flag( - "xla_generate_hlo_graph", &flags->xla_generate_hlo_graph, - "If non-empty, dump graph of HLO modules whose name partially " - "matches this regex. E.g. --xla_generate_hlo_graph=.* will dump " - "the graph of every module."), - tensorflow::Flag("xla_hlo_graph_addresses", - &flags->xla_hlo_graph_addresses, - "Show addresses of HLO ops in graph"), - tensorflow::Flag("xla_hlo_graph_layout", &flags->xla_hlo_graph_layout, - "Show layout of HLO ops in graph"), tensorflow::Flag( "xla_hlo_graph_for_compute_constant", &flags->xla_hlo_graph_for_compute_constant, @@ -72,9 +52,6 @@ static void AllocateFlags() { &flags->xla_dump_computations_to, "Dumps computations that XLA executes into the provided " "directory path"), - tensorflow::Flag("xla_dump_hlo_text_to", &flags->xla_dump_hlo_text_to, - "Dumps HLO modules that XLA executes into the provided " - "directory path"), tensorflow::Flag("xla_dump_executions_to", &flags->xla_dump_executions_to, "Dumps parameters and results of computations that XLA " "executes into the provided directory path"), diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.h b/tensorflow/compiler/xla/legacy_flags/service_flags.h index d982506944daed41eb6e7c4a238d540b38cf8be3..72d0c52402c43b41e6de3ef4638929bd682d5029 100644 --- a/tensorflow/compiler/xla/legacy_flags/service_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/service_flags.h @@ -34,23 +34,11 @@ void AppendServiceFlags(std::vector* flag_list); typedef struct { bool xla_hlo_profile; // Instrument the computation to collect per-HLO cycle // counts - string xla_log_hlo_text; // If non-empty, print the text format of the HLO - // modules whose name partially - // matches this regex. E.g. 
xla_log_hlo_text=.* - // will dump the text for every module. - string xla_generate_hlo_graph; // If non-empty, dump graph of HLO modules - // whose name partially matches this regex. - // E.g. --xla_generate_hlo_graph=.* will dump - // the graph of every module. - bool xla_hlo_graph_addresses; // Show addresses of HLO ops in graph - bool xla_hlo_graph_layout; // Show layout of HLO ops in graph bool xla_hlo_graph_for_compute_constant; // If true, include hlo dumps of // graphs from ComputeConstant. // Such graphs still need to be // matched via // xla_generate_hlo_graph. - string xla_dump_hlo_text_to; // Dumps HLO text for each HLO module that is - // executed into the provided directory path string xla_dump_computations_to; // Dumps computations that XLA executes // into the provided directory path // Dumps parameters and results of computations that XLA executes into diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc deleted file mode 100644 index a9597d0cd8f89d7d664c38b79d225b0aa6b6b13b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include - -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static UserComputationFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new UserComputationFlags; - flags->xla_eliminate_hlo_implicit_broadcast = false; - flag_list = new std::vector({ - tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast", - &flags->xla_eliminate_hlo_implicit_broadcast, - "Eliminate implicit broadcast on when lowering user " - "computation to HLO instructions, use explicit " - "broadcast instead."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline -// module. -void AppendUserComputationFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the UserComputationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-UserComputationFlags* GetUserComputationFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h deleted file mode 100644 index f5222c927cb203b901fb3bc6ea3d2e7d30cb658a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ - -// Legacy flags for XLA's user_computation module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flags definitions associated with XLA's user_computation -// module. -void AppendUserComputationFlags(std::vector* flag_list); - -typedef struct { - // Eliminate implicit broadcast on when lowering user computation to HLO - // instructions, use explicit broadcast instead. 
- bool xla_eliminate_hlo_implicit_broadcast; -} UserComputationFlags; - -// Return a pointer to the UserComputationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -UserComputationFlags* GetUserComputationFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index caef3a3869f4bcde7a6982ce3dfc0db9d36cbc5e..6760f72e55f8920e1948c0b1f5fd8751dbe858e0 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -62,7 +62,17 @@ Literal::StrideConfig::StrideConfig( std::unique_ptr Literal::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(); *literal->mutable_shape() = shape; - literal->Reserve(ShapeUtil::ElementsIn(literal->shape())); + if (ShapeUtil::IsTuple(shape)) { + int64 num_elements = ShapeUtil::TupleElementCount(shape); + literal->tuple_literals_.resize(num_elements); + for (int i = 0; i < num_elements; ++i) { + std::unique_ptr elem = + CreateFromShape(ShapeUtil::GetTupleElementShape(shape, i)); + literal->tuple_literals_[i] = std::move(*elem); + } + } else { + literal->Reserve(ShapeUtil::ElementsIn(literal->shape())); + } return literal; } @@ -321,6 +331,7 @@ Status Literal::Copy(const Literal& src_literal, } std::unique_ptr Literal::Relayout(const Layout& layout) const { + CHECK(ShapeUtil::IsArray(shape())); std::unique_ptr result = CloneToUnique(); *result->mutable_shape()->mutable_layout() = layout; @@ -754,10 +765,30 @@ void Literal::EachCellAsString( } namespace { +template +std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { + auto result_literal = MakeUnique(); + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = src_literal.shape(); + result_shape->set_element_type( + 
primitive_util::NativeToPrimitiveType()); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + tensorflow::gtl::ArraySlice src_data = + src_literal.GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = static_cast(src_data[i]); + } + return result_literal; +} + template std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< + return ConvertBetweenNativeTypes< typename primitive_util::PrimitiveTypeToNative::type, typename primitive_util::PrimitiveTypeToNative< primitive_dest_type>::type>(src_literal); @@ -782,19 +813,20 @@ StatusOr> ConvertIfDestTypeMatches( #undef CONVERT_IF_TYPES_MATCH // Other types are not yet supported. default: - return tensorflow::errors::InvalidArgument( - "Unimplemented: ConvertIfDestTypeMatches for type " + - PrimitiveType_Name(src_literal.shape().element_type())); + return InvalidArgument( + "Unimplemented: Convert from type %s to type %s", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } } -} +} // namespace -StatusOr> LiteralUtil::ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { +StatusOr> Literal::Convert( + PrimitiveType primitive_dest_type) const { + switch (shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); + return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type); CONVERT_IF_DEST_TYPE_MATCHES(PRED) CONVERT_IF_DEST_TYPE_MATCHES(S8) CONVERT_IF_DEST_TYPE_MATCHES(S32) @@ -807,9 +839,9 @@ StatusOr> LiteralUtil::ConvertIfSrcTypeMatches( #undef 
CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return tensorflow::errors::InvalidArgument( - "Unimplemented: ConvertIfSrcTypeMatches for type " + - PrimitiveType_Name(src_literal.shape().element_type())); + return InvalidArgument("Unimplemented: Convert from type %s to type %s", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } } @@ -971,7 +1003,7 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // support in protobuf auto values = mutable_f16s(); return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(&(*values)[0]), values->size() / sizeof(half)); + reinterpret_cast(&(*values)[0]), values->size()); } template <> diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 42c8b61acec8f4dc661111affc17773b1aa71583..8266511614d6c4f7593e8143f04c80e6662c118f 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -68,11 +68,13 @@ class BoolVector { } BoolVector(const BoolVector& other) { CopyFrom(other); } + BoolVector(BoolVector&&) = default; BoolVector& operator=(const BoolVector& other) { CopyFrom(other); return *this; } + BoolVector& operator=(BoolVector&&) = default; void push_back(const bool& value) { resize(size_ + 1); @@ -147,10 +149,12 @@ class Literal { Literal() {} Literal(const Literal& other) = default; + Literal(Literal&&) = default; explicit Literal(const LiteralProto& other) { CopyFromProto(other); } Literal& operator=(const Literal& other) = default; + Literal& operator=(Literal&&) = default; LiteralProto ToProto() const; @@ -251,7 +255,7 @@ class Literal { *other = temp; } - // CreatesCreate new literal of a given rank. To minimize ambiguity (for users + // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the // native type. 
For example: // @@ -362,10 +366,10 @@ class Literal { template std::unique_ptr Replicate(int64 times) const; - // Creates a literal by converting each element in this literal to a new - // type. - template - std::unique_ptr Convert() const; + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; // Creates a literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -444,10 +448,21 @@ class Literal { template void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - // Retrieves the mutable array slice interface which can be used to manipulate - // pre-allocated literal values. + // Returns a (Mutable)ArraySlice view of the array for this literal for the + // given NativeT (e.g., float). These functions map native type to XLA + // PrimitiveType via template specialization. The unspecialized forms below + // aborts to handle the error case where the given native type does not map to + // an XLA primitive type. template - tensorflow::gtl::MutableArraySlice GetMutableArraySlice(); + tensorflow::gtl::ArraySlice GetArraySlice() const { + static_assert(!std::is_same::value, + "Cannot map native type to primitive type."); + } + template + tensorflow::gtl::MutableArraySlice GetMutableArraySlice() { + static_assert(!std::is_same::value, + "Cannot map native type to primitive type."); + } // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -588,17 +603,6 @@ class Literal { bool IsZero(tensorflow::gtl::ArraySlice indices) const; private: - // Returns an ArraySlice view of the array for this literal for the given - // NativeT (e.g., float). These functions map native type to XLA PrimitiveType - // via template specialization. 
The unspecialized forms below aborts to handle - // the error case where the given native type does not map to an XLA primitive - // type. - template - tensorflow::gtl::ArraySlice GetArraySlice() const { - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } - // Copy from a LiteralProto instance. void CopyFromProto(const LiteralProto& literal_proto); @@ -646,544 +650,6 @@ class Literal { std::vector tuple_literals_; }; -// Utility class for dealing with XLA literal values. Most methods are -// templated by native (host) type which corresponds to a unique XLA -// PrimitiveType. See ComputationBuilder for details. Not all primitive types -// defined in xla_data.proto have a corresponding native type or even have a -// storage location in the Literal proto yet (for example, primitive type F16). -// -// TODO(dnovillo) - All functions in this class simply redirect to the -// corresponding function in class Literal. Remove this class after converting -// all user code to use Literal directly. -class LiteralUtil { - public: - // Creates new literal of a given rank. To minimize ambiguity (for users and - // the compiler) these CreateR[0-2] methods should explicitly specify the - // native type. For example: - // - // CreateR1({1.0, 42.0}); - // CreateR2({{1, 2}, {3, 4}}); - // - // The variants not ending with WithLayout use the default XLA layout for the - // literal's linear representation in memory. 
- template - static std::unique_ptr CreateR0(NativeT value) { - return Literal::CreateR0(value); - } - - template - static std::unique_ptr CreateR1( - tensorflow::gtl::ArraySlice values) { - return Literal::CreateR1(values); - } - - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values) { - return Literal::CreateR1(values); - } - - template - static std::unique_ptr CreateR2( - std::initializer_list> values) { - return Literal::CreateR2(values); - } - - template - static std::unique_ptr CreateR2WithLayout( - std::initializer_list> values, - const Layout& layout) { - return Literal::CreateR2WithLayout(values, layout); - } - - template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values) { - return Literal::CreateR3(values); - } - - template - static std::unique_ptr CreateR3WithLayout( - std::initializer_list< - std::initializer_list>> - values, - const Layout& layout) { - return Literal::CreateR3WithLayout(values, layout); - } - - template - static std::unique_ptr CreateR4( - std::initializer_list>>> - values) { - return Literal::CreateR4(values); - } - - template - static std::unique_ptr CreateR4WithLayout( - std::initializer_list>>> - values, - const Layout& layout) { - return Literal::CreateR4WithLayout(values, layout); - } - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromShape(const Shape& shape) { - return Literal::CreateFromShape(shape); - } - - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). 
- static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions) { - return Literal::CreateFromDimensions(primitive_type, dimensions); - } - - // Copies the values from src_literal, starting at src_base shape indexes, - // to dest_literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // - // The src_literal and dest_literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - static Status Copy(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - return dest_literal->Copy(src_literal, src_base, dest_base, copy_size); - } - - // Creates a new value that has the equivalent value as literal, but conforms - // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major - // dimension layout can be re-laid-out as {1, 0} minor-to-major dimension - // layout and the value in the cell at any given logical index (i0, i1) will - // be the same. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - static std::unique_ptr Relayout(const Literal& literal, - const Layout& new_layout) { - return literal.Relayout(new_layout); - } - - // Reshapes literal 'input' to have 'shape'. Both the original shape and - // 'shape' must contain the same number of elements. The implementation - // currently only supports monotonic dim0-major layouts. 
- static StatusOr> Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice shape) { - return input.Reshape(shape); - } - - // Creates a new literal by reordering the dimensions of the original literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - static std::unique_ptr Transpose( - const Literal& literal, tensorflow::gtl::ArraySlice permutation) { - return literal.Transpose(permutation); - } - - // Creates a sub-array from the given literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - static std::unique_ptr Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) { - return literal.Slice(start_indices, limit_indices); - } - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input - // literal replicated four times. - template - static std::unique_ptr Replicate(const Literal& input, int64 times) { - return input.Replicate(times); - } - - // Creates a literal by converting each element in an original literal to a - // new type. 
- template - static std::unique_ptr Convert(const Literal& literal) { - return literal.Convert(); - } - - // Convert a literal to another primitive type, but only if the literal - // type is connvertable into the destination type - static StatusOr> ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type); - - // Creates a literal value zero of the given primitive type. - static Literal Zero(PrimitiveType primitive_type) { - return Literal::Zero(primitive_type); - } - - // Creates a literal value one of the given primitive type. - static Literal One(PrimitiveType primitive_type) { - return Literal::One(primitive_type); - } - - // Creates a literal value containing the minimum value of the given - // primitive type. For floating-point types, returns -inf. - static Literal MinValue(PrimitiveType primitive_type) { - return Literal::MinValue(primitive_type); - } - - // Creates a literal value containing the maximum value of the given - // primitive type. For floating-point types, returns inf. - static Literal MaxValue(PrimitiveType primitive_type) { - return Literal::MaxValue(primitive_type); - } - - // Creates a literal of the given shape where each element is `value`. - template - static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value) { - return Literal::CreateFullWithMonotonicDim0MajorLayout(dimensions, value); - } - - // Creates a new literal from an array. The variants not ending with - // WithLayout use the default XLA layout for the literal's linear - // representation in memory. 
- template - static std::unique_ptr CreateR2FromArray2D( - const Array2D& values) { - return Literal::CreateR2FromArray2D(values); - } - - template - static std::unique_ptr CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return Literal::CreateR2FromArray2DWithLayout(values, layout); - } - - template - static std::unique_ptr CreateR3FromArray3D( - const Array3D& values) { - return Literal::CreateR3FromArray3D(values); - } - - template - static std::unique_ptr CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return Literal::CreateR3FromArray3DWithLayout(values, layout); - } - - template - static std::unique_ptr CreateR4FromArray4D( - const Array4D& values) { - return Literal::CreateR4FromArray4D(values); - } - - template - static std::unique_ptr CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return Literal::CreateR4FromArray4DWithLayout(values, layout); - } - - // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(tensorflow::StringPiece value) { - return Literal::CreateR1U8(value); - } - - // Creates a linspace-populated literal with the given number of rows and - // columns. - static std::unique_ptr CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols) { - return Literal::CreateR2F32Linspace(from, to, rows, cols); - } - - // Creates a literal that projects the (x, y) dimensions given in values into - // the z dimension given by "projection". - template - static std::unique_ptr CreateR3Projected( - std::initializer_list> values, - int64 projection) { - return Literal::CreateR3Projected(values, projection); - } - - // Creates a literal that projects the (x, y) dimensions given in values into - // the z and p dimensions given. 
- template - static std::unique_ptr CreateR4Projected( - std::initializer_list> values, - int64 projection_p, int64 projection_z) { - return Literal::CreateR4Projected(values, projection_p, projection_z); - } - - // Clones literal into an owned unique_ptr version. - static std::unique_ptr CloneToUnique(const Literal& literal) { - return literal.CloneToUnique(); - } - - // Returns the linear index of the given index within the literal's - // element_type repeated field. - static int64 LinearIndex(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.LinearIndex(multi_index); - } - - // Gets or sets an element in the literal at the given index. The index is - // CHECKed against the dimension sizes. - template - static NativeT Get(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.Get(multi_index); - } - - template - static void Set(Literal* literal, - tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - literal->Set(multi_index, value); - } - - // Retrieves the mutable array slice interface which can be used to manipulate - // pre-allocated literal values. - template - static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( - Literal* literal) { - return literal->GetMutableArraySlice(); - } - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - static NativeT GetFirstElement(const Literal& literal) { - return literal.GetFirstElement(); - } - - // As Get(), but determines the correct type and converts the value - // into text. - static string GetAsString(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.GetAsString(multi_index); - } - - // Returns an identity matrix (rank 2) with the given row and column count. - template - static std::unique_ptr MakeIdentityR2(int64 size) { - return Literal::MakeIdentityR2(size); - } - - // Returns a tuple literal composed of given literals. 
- static std::unique_ptr MakeTuple( - tensorflow::gtl::ArraySlice elements) { - return Literal::MakeTuple(elements); - } - - // Validates that the data payload of the literal matches the literal shape; - // if it does not, an appropriate status is returned. - static tensorflow::Status ValidateLiteral(const Literal& literal) { - return literal.ValidateLiteral(); - } - - // Returns a string representation of the literal value. - static string ToString(const Literal& literal) { return literal.ToString(); } - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - static void EachCellAsString( - const Literal& literal, - const std::function indices, - const string& value)>& per_cell) { - literal.EachCellAsString(per_cell); - } - - template - static void EachCell( - const Literal& literal, - std::function indices, - NativeT value)> - per_cell) { - literal.EachCell(per_cell); - } - - // Templated methods which populate the given repeated field in the Literal - // proto with the given value(s). The Shape field of the Literal proto is set - // to match the array dimensions and type. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // PopulateR2FromArray2D(values, literal); - // - // // Populate with int32s. 
- // PopulateR2({{1, 2}, {3, 4}}, literal); - // - template - static void PopulateR0(NativeT values, Literal* literal) { - literal->PopulateR0(values); - } - - template - static void PopulateR1(tensorflow::gtl::ArraySlice values, - Literal* literal) { - literal->PopulateR1(values); - } - - static void PopulateR1(const tensorflow::core::Bitmap& values, - Literal* literal) { - literal->PopulateR1(values); - } - - template - static void PopulateR2( - std::initializer_list> values, - Literal* literal) { - literal->PopulateR2(values); - } - - template - static void PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout, Literal* literal) { - literal->PopulateR2WithLayout(values, layout); - } - - template - static void PopulateR2FromArray2D(const Array2D& values, - Literal* literal) { - literal->PopulateR2FromArray2D(values); - } - - template - static void PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR2FromArray2DWithLayout(values, layout); - } - - template - static void PopulateR3FromArray3D(const Array3D& values, - Literal* literal) { - literal->PopulateR3FromArray3D(values); - } - - template - static void PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR3FromArray3DWithLayout(values, layout); - } - - template - static void PopulateR4FromArray4D(const Array4D& values, - Literal* literal) { - literal->PopulateR4FromArray4D(values); - } - - template - static void PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR4FromArray4DWithLayout(values, layout); - } - - // Populates literal values by calling the generator function for every cell - // in the literal object. 
- template - static Status Populate( - Literal* literal, - const std::function indexes)>& - generator) { - return literal->Populate(generator); - } - - // Creates a Literal of the given dimensions with all elements set to the - // given value. - template - static void PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - return literal->PopulateWithValue(value, dimensions); - } - - // Returns a pointer to the underlying vector containing the array data. Use - // with care. - static const void* InternalData(const Literal& literal) { - return literal.InternalData(); - } - - static void* MutableInternalData(Literal* literal) { - return literal->MutableInternalData(); - } - - // Allocates space in the underlying vector of the literal sufficient to hold - // num_elements of the literal's primitive type. Values in the vector are set - // to zero. num_elements must equal the number of elements in the literals - // shape. - static void Reserve(int64 num_elements, Literal* literal) { - literal->Reserve(num_elements); - } - - // Allocates space in the underlying vector of the literal sufficient to hold - // num_elements of the literal's primitive type and sets each element in the - // literal to the given value. num_elements must equal the number of elements - // in the literals shape. - template - static void Resize(int64 num_elements, NativeT value, Literal* literal) { - literal->Resize(num_elements, value); - } - - // Returns true if the two given literals have the same shape and - // values. Layout is not considered in the comparison. - static bool Equal(const Literal& literal1, const Literal& literal2) { - return literal1.Equal(literal2); - } - - // Returns whether every element in the given literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) 
and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in literal's type, returns false. Values of 1/0 are - // considered equal to true/false; other values are not considered equal to - // true. - static bool IsAll(const Literal& literal, int8 value) { - return literal.IsAll(value); - } - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. - static bool IsAllFloat(const Literal& literal, float value) { - return literal.IsAllFloat(value); - } - - // Returns whether the literal is zero at the specified index. The literal - // must be an array. - static bool IsZero(const Literal& literal, - tensorflow::gtl::ArraySlice indices) { - return literal.IsZero(indices); - } - - TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); -}; - // Declarations of template specializations for GetArraySlice and // GetMutableArraySlice. The specializations map native type to XLA primitive // type. 
@@ -1759,27 +1225,6 @@ void Literal::PopulateWithValue(NativeT value, Resize(ShapeUtil::ElementsIn(shape()), value); } -template -std::unique_ptr Literal::Convert() const { - const Shape& this_shape = shape(); - auto result_literal = MakeUnique(); - Shape* result_shape = result_literal->mutable_shape(); - *result_shape = this_shape; - result_shape->set_element_type( - primitive_util::NativeToPrimitiveType()); - result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); - tensorflow::gtl::ArraySlice src_data = - GetArraySlice(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->GetMutableArraySlice(); - int64 num_elements = ShapeUtil::ElementsIn(this_shape); - - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast(src_data[i]); - } - return result_literal; -} - template /* static */ std::unique_ptr Literal::CreateFullWithMonotonicDim0MajorLayout( diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 8d4a75d7affebd3ee39702cb1226ee52aff09691..6c3648e1e07306dbef5a7c2b37e3d7d873206a5b 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -72,11 +72,11 @@ class LiteralUtilTest : public ::testing::Test { layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3}); literal_r4_2x2x3x3_dim0major_ = - LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0major_); + Literal::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0major_); literal_r4_2x2x3x3_dim0minor_ = - LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0minor_); + Literal::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0minor_); } Layout layout_r2_dim0major_; @@ -90,43 +90,42 @@ class LiteralUtilTest : public ::testing::Test { }; TEST_F(LiteralUtilTest, LiteralScalarToString) { - auto true_lit = LiteralUtil::CreateR0(true); - ASSERT_EQ("true", LiteralUtil::ToString(*true_lit)); + auto true_lit = Literal::CreateR0(true); + 
ASSERT_EQ("true", true_lit->ToString()); - auto false_lit = LiteralUtil::CreateR0(false); - ASSERT_EQ("false", LiteralUtil::ToString(*false_lit)); + auto false_lit = Literal::CreateR0(false); + ASSERT_EQ("false", false_lit->ToString()); - auto u32_lit = LiteralUtil::CreateR0(42); - ASSERT_EQ("42", LiteralUtil::ToString(*u32_lit)); + auto u32_lit = Literal::CreateR0(42); + ASSERT_EQ("42", u32_lit->ToString()); - auto s32_lit = LiteralUtil::CreateR0(-999); - ASSERT_EQ("-999", LiteralUtil::ToString(*s32_lit)); + auto s32_lit = Literal::CreateR0(-999); + ASSERT_EQ("-999", s32_lit->ToString()); - auto f32_lit = LiteralUtil::CreateR0(3.14f); - ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); + auto f32_lit = Literal::CreateR0(3.14f); + ASSERT_EQ("3.14", f32_lit->ToString()); - auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", LiteralUtil::ToString(*f16_lit)); + auto f16_lit = Literal::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", f16_lit->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { - auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - ASSERT_EQ("{101}", LiteralUtil::ToString(*pred_vec)); + auto pred_vec = Literal::CreateR1({true, false, true}); + ASSERT_EQ("{101}", pred_vec->ToString()); } TEST_F(LiteralUtilTest, R2ToString) { - const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); const string expected = R"(s32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 }, })"; - ASSERT_EQ(expected, LiteralUtil::ToString(*literal)); + ASSERT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, R3ToString) { - const auto literal = - LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); + const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { { { 1 }, { 2 } }, @@ -135,13 +134,13 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - ASSERT_EQ(expected, 
LiteralUtil::ToString(*literal)); + ASSERT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, TupleToString) { - auto scalar = LiteralUtil::CreateR0(1.0); - auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -149,7 +148,7 @@ f32[2,2] { { 3, 4 }, }, ))"; - ASSERT_EQ(expected, LiteralUtil::ToString(*tuple)); + ASSERT_EQ(expected, tuple->ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -164,9 +163,9 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { }); // clang-format on - auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); + auto literal = Literal::CreateR3FromArray3D(array_3d); EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); - string result = LiteralUtil::ToString(*literal); + string result = literal->ToString(); const string expected = R"(f32[2,3,2] { { { 1, 2 }, { 3, 4 }, @@ -180,14 +179,14 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off - auto literal = LiteralUtil::CreateR4Projected({ + auto literal = Literal::CreateR4Projected({ {1, 2}, {1001, 1002}, {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); - string result = LiteralUtil::ToString(*literal); + string result = literal->ToString(); const string expected = R"(f32[1,2,3,2] { { // i0=0 { // i1=0 @@ -208,7 +207,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), ElementsAre(2, 2, 3, 3)); - string result = 
LiteralUtil::ToString(*literal_r4_2x2x3x3_dim0major_); + string result = literal_r4_2x2x3x3_dim0major_->ToString(); const string expected = R"(f32[2,2,3,3] { { // i0=0 { // i1=0 @@ -240,14 +239,13 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format off - auto literal = LiteralUtil::CreateR2({ + auto literal = Literal::CreateR2({ {3.1f, 4.2f}, {9.3f, 12.4f}, }); // clang-format on std::vector> seen; - LiteralUtil::EachCellAsString( - *literal, + literal->EachCellAsString( [&seen](tensorflow::gtl::ArraySlice indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -259,176 +257,171 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { } TEST_F(LiteralUtilTest, ScalarEquality) { - // Test LiteralUtil::Equal with scalars. - auto f32_42 = LiteralUtil::CreateR0(42.0); - auto f32_42_clone = LiteralUtil::CreateR0(42.0); + // Test Literal::Equal with scalars. + auto f32_42 = Literal::CreateR0(42.0); + auto f32_42_clone = Literal::CreateR0(42.0); - EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42)); - EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42_clone)); + EXPECT_TRUE(f32_42->Equal(*f32_42)); + EXPECT_TRUE(f32_42->Equal(*f32_42_clone)); - auto f32_123 = LiteralUtil::CreateR0(123.0); - EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f32_123)); + auto f32_123 = Literal::CreateR0(123.0); + EXPECT_FALSE(f32_42->Equal(*f32_123)); - auto f64_42 = LiteralUtil::CreateR0(42.0); - EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f64_42)); + auto f64_42 = Literal::CreateR0(42.0); + EXPECT_FALSE(f32_42->Equal(*f64_42)); } TEST_F(LiteralUtilTest, NonScalarEquality) { - // Test LiteralUtil::Equal with nonscalars. 
- auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_clone = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_different = - LiteralUtil::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); - auto vector_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); - auto scalar = LiteralUtil::CreateR0(1.0); - - EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix)); - EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix_clone)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *matrix_different)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *vector_literal)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *scalar)); + // Test Literal::Equal with nonscalars. + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_clone = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_different = Literal::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); + auto vector_literal = Literal::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto scalar = Literal::CreateR0(1.0); + + EXPECT_TRUE(matrix->Equal(*matrix)); + EXPECT_TRUE(matrix->Equal(*matrix_clone)); + EXPECT_FALSE(matrix->Equal(*matrix_different)); + EXPECT_FALSE(matrix->Equal(*vector_literal)); + EXPECT_FALSE(matrix->Equal(*scalar)); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { - // Test LiteralUtil::Equal with literals which have different layouts. + // Test Literal::Equal with literals which have different layouts. 
auto colmajor = MakeUnique(); *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - LiteralUtil::Reserve(4, colmajor.get()); - LiteralUtil::Set(colmajor.get(), {0, 0}, 1.0); - LiteralUtil::Set(colmajor.get(), {0, 1}, 2.0); - LiteralUtil::Set(colmajor.get(), {1, 0}, 3.0); - LiteralUtil::Set(colmajor.get(), {1, 1}, 4.0); + colmajor->Reserve(4); + colmajor->Set({0, 0}, 1.0); + colmajor->Set({0, 1}, 2.0); + colmajor->Set({1, 0}, 3.0); + colmajor->Set({1, 1}, 4.0); auto rowmajor = MakeUnique(); *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - LiteralUtil::Reserve(4, rowmajor.get()); - LiteralUtil::Set(rowmajor.get(), {0, 0}, 1.0); - LiteralUtil::Set(rowmajor.get(), {0, 1}, 2.0); - LiteralUtil::Set(rowmajor.get(), {1, 0}, 3.0); - LiteralUtil::Set(rowmajor.get(), {1, 1}, 4.0); + rowmajor->Reserve(4); + rowmajor->Set({0, 0}, 1.0); + rowmajor->Set({0, 1}, 2.0); + rowmajor->Set({1, 0}, 3.0); + rowmajor->Set({1, 1}, 4.0); - EXPECT_TRUE(LiteralUtil::Equal(*rowmajor, *colmajor)); + EXPECT_TRUE(rowmajor->Equal(*colmajor)); } TEST_F(LiteralUtilTest, TupleEquality) { - // Test LiteralUtil::Equal with tuples. - auto scalar = LiteralUtil::CreateR0(1.0); - auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + // Test Literal::Equal with tuples. + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. 
- auto scalar_clone = LiteralUtil::CreateR0(1.0); - auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_TRUE(LiteralUtil::Equal(*tuple1, *tuple2)); + auto scalar_clone = Literal::CreateR0(1.0); + auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); + EXPECT_TRUE(tuple1->Equal(*tuple2)); // Tuple with elements reversed. - auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *reversed_tuple)); + auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); + EXPECT_FALSE(tuple1->Equal(*reversed_tuple)); // Tuple with different value. - auto scalar_42 = LiteralUtil::CreateR0(42.0); - auto different_tuple = - LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *different_tuple)); + auto scalar_42 = Literal::CreateR0(42.0); + auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); + EXPECT_FALSE(tuple1->Equal(*different_tuple)); } TEST_F(LiteralUtilTest, IsAllTuple) { - auto element1 = LiteralUtil::CreateR0(0.0); - auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + auto element1 = Literal::CreateR0(0.0); + auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + auto tuple = Literal::MakeTuple({element1.get(), element1.get()}); // Tuples should always return false for IsAll. - EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 0)); - EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 1)); + EXPECT_FALSE(tuple->IsAll(0)); + EXPECT_FALSE(tuple->IsAll(1)); +} + +// Verifies that CreateFromShape works for tuples. 
+TEST_F(LiteralUtilTest, CreateFromShapeTuple) { + auto scalar = Literal::CreateR0(0.0); + auto matrix = Literal::CreateR2({{0, 0}, {0, 0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + + auto x = Literal::CreateFromShape(tuple->shape()); + EXPECT_TRUE(tuple->Equal(*x)); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 0)); - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 1)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 1)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 2)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 0)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 2)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), -1)); + EXPECT_TRUE(Literal::CreateR0(false)->IsAll(0)); + EXPECT_TRUE(Literal::CreateR0(true)->IsAll(1)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAll(1)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAll(2)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(0)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(2)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. 
auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR0(255), int8_min)); + EXPECT_FALSE(Literal::CreateR0(255)->IsAll(int8_min)); - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(42.0), 42)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(42.0001), 42)); + EXPECT_TRUE(Literal::CreateR0(42.0)->IsAll(42)); + EXPECT_FALSE(Literal::CreateR0(42.0001)->IsAll(42)); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR1({100, 100, 100}), 100)); - EXPECT_FALSE(LiteralUtil::IsAll( - *LiteralUtil::CreateR1({100, 100, 100.001}), 100)); + EXPECT_TRUE(Literal::CreateR1({100, 100, 100})->IsAll(100)); + EXPECT_FALSE(Literal::CreateR1({100, 100, 100.001})->IsAll(100)); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{8, 8}, {8, 8}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{8, 8}, {8, 9}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{9, 8}, {8, 8}}), 8)); + EXPECT_TRUE(Literal::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h8}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h9}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h9}, {h8}}), 8)); + EXPECT_TRUE(Literal::CreateR2({{h8}, {h8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); auto uint64_max = std::numeric_limits::max(); - EXPECT_FALSE(LiteralUtil::IsAll( - *LiteralUtil::CreateR2( - {{uint64_max, uint64_max}, {uint64_max, uint64_max}}), - -1)); + EXPECT_FALSE(Literal::CreateR2( + {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) + ->IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not 
floating-point. - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(false), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(.5), .5)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.5)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + + EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.49)); - EXPECT_FALSE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}), .5)); - - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(.5), .5)); + Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); EXPECT_TRUE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.5)); + Literal::CreateR2({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5)); + + EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.49)); - EXPECT_FALSE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}), 0)); + 
Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsZero) { - auto scalar_zero = LiteralUtil::CreateR0(0.0f); - auto scalar_one = LiteralUtil::CreateR0(1.0f); - EXPECT_TRUE(LiteralUtil::IsZero(*scalar_zero, {})); - EXPECT_FALSE(LiteralUtil::IsZero(*scalar_one, {})); - - auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); - EXPECT_FALSE(LiteralUtil::IsZero(*array, {0, 1})); - EXPECT_TRUE(LiteralUtil::IsZero(*array, {0, 2})); - EXPECT_TRUE(LiteralUtil::IsZero(*array, {1, 1})); - EXPECT_FALSE(LiteralUtil::IsZero(*array, {1, 2})); + auto scalar_zero = Literal::CreateR0(0.0f); + auto scalar_one = Literal::CreateR0(1.0f); + EXPECT_TRUE(scalar_zero->IsZero({})); + EXPECT_FALSE(scalar_one->IsZero({})); + + auto array = Literal::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); + EXPECT_FALSE(array->IsZero({0, 1})); + EXPECT_TRUE(array->IsZero({0, 2})); + EXPECT_TRUE(array->IsZero({1, 1})); + EXPECT_FALSE(array->IsZero({1, 2})); } template @@ -440,127 +433,122 @@ TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. 
TypeParam half = TypeParam(1) / TypeParam(2); - auto data = LiteralUtil::CreateR2({{half, 2}, {3, 4}}); + auto data = Literal::CreateR2({{half, 2}, {3, 4}}); const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); - auto data01 = LiteralUtil::Relayout(*data, layout01); + auto data01 = data->Relayout(layout01); EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_TRUE(LiteralUtil::Equal(*data, *data01)); + EXPECT_TRUE(data->Equal(*data01)); - auto data10 = LiteralUtil::Relayout(*data, layout10); + auto data10 = data->Relayout(layout10); EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_TRUE(LiteralUtil::Equal(*data, *data10)); + EXPECT_TRUE(data->Equal(*data10)); } TEST_F(LiteralUtilTest, ReshapeR0) { - auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = - LiteralUtil::Reshape(*original, /*shape=*/{}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape)); + auto original = Literal::CreateR0(1.7f); + auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); + EXPECT_TRUE(original->Equal(*reshape)); } TEST_F(LiteralUtilTest, ReshapeR4) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // F32[1x3x4x2] - auto expected = LiteralUtil::CreateR3WithLayout({ + auto expected = Literal::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); + 
EXPECT_TRUE(expected->Equal(*reshape)); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0minor_); // F32[1x3x4x2] - auto expected = LiteralUtil::CreateR3WithLayout({ + auto expected = Literal::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); + EXPECT_TRUE(expected->Equal(*reshape)); } TEST_F(LiteralUtilTest, TransposeR0) { - auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{}); - EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape)); + auto original = Literal::CreateR0(1.7f); + auto reshape = original->Transpose(/*permutation=*/{}); + EXPECT_TRUE(original->Equal(*reshape)); } TEST_F(LiteralUtilTest, TransposeR4) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4({{ + auto original = Literal::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}); // clang-format on - auto reshape = - LiteralUtil::Transpose(*original, /*permutation=*/{2, 3, 0, 1}); - - LiteralUtil::EachCell( - *reshape, [&](tensorflow::gtl::ArraySlice indices, float value) { - EXPECT_EQ(value, - LiteralUtil::Get(*original, {indices[2], indices[3], - indices[0], indices[1]})); + auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); + + reshape->EachCell( + [&](tensorflow::gtl::ArraySlice indices, float value) { + 
EXPECT_EQ(value, original->Get( + {indices[2], indices[3], indices[0], indices[1]})); }); } TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. - auto dim0minor_relaid_to_dim0major = LiteralUtil::Relayout( - *literal_r4_2x2x3x3_dim0minor_, layout_r4_dim0major_); - EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0major_, - *dim0minor_relaid_to_dim0major)); - - auto dim0major_relaid_to_dim0minor = LiteralUtil::Relayout( - *literal_r4_2x2x3x3_dim0major_, layout_r4_dim0minor_); - EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0minor_, - *dim0major_relaid_to_dim0minor)); + auto dim0minor_relaid_to_dim0major = + literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); + EXPECT_TRUE( + literal_r4_2x2x3x3_dim0major_->Equal(*dim0minor_relaid_to_dim0major)); + + auto dim0major_relaid_to_dim0minor = + literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); + EXPECT_TRUE( + literal_r4_2x2x3x3_dim0minor_->Equal(*dim0major_relaid_to_dim0minor)); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. - auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( - {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); + auto mat_dim0minor = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, + layout_r2_dim0minor_); EXPECT_EQ(mat_dim0minor->s32s_size(), 6); EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. - auto relaid_mat_to_dim0major = - LiteralUtil::Relayout(*mat_dim0minor, layout_r2_dim0major_); + auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). 
- auto mat_dim0major = LiteralUtil::CreateR2WithLayout( - {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); + auto mat_dim0major = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, + layout_r2_dim0major_); EXPECT_EQ(mat_dim0major->s32s_size(), 6); EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. - auto relaid_mat_to_dim0minor = - LiteralUtil::Relayout(*mat_dim0major, layout_r2_dim0minor_); + auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); } @@ -578,8 +566,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { {10, 11, 12}, }, }); // clang-format on - auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( - arr3d, layout_r3_dim0minor_); + auto lit_dim0minor = + Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0minor_); EXPECT_EQ(lit_dim0minor->s32s_size(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; @@ -587,122 +575,120 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. - auto relaid_lit_to_dim0major = - LiteralUtil::Relayout(*lit_dim0minor, layout_r3_dim0major_); + auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; EXPECT_THAT(relaid_lit_to_dim0major->s32s(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). 
- auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( - arr3d, layout_r3_dim0major_); + auto lit_dim0major = + Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0major_); EXPECT_EQ(lit_dim0major->s32s_size(), 12); EXPECT_THAT(lit_dim0major->s32s(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. - auto relaid_lit_to_dim0minor = - LiteralUtil::Relayout(*lit_dim0major, layout_r3_dim0minor_); + auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); EXPECT_THAT(relaid_lit_to_dim0minor->s32s(), testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { - auto input = LiteralUtil::CreateR0(1); - auto result = LiteralUtil::Slice(*input, {}, {}); - EXPECT_TRUE(LiteralUtil::Equal(*input, *result)); + auto input = Literal::CreateR0(1); + auto result = input->Slice({}, {}); + EXPECT_TRUE(input->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR1F32) { - auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); - auto result = LiteralUtil::Slice(*input, {3}, {4}); - auto expected = LiteralUtil::CreateR1({4.0}); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *result)); + auto input = Literal::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); + auto result = input->Slice({3}, {4}); + auto expected = Literal::CreateR1({4.0}); + EXPECT_TRUE(expected->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR2U32) { - auto input_3x4 = LiteralUtil::CreateR2( - {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto result = LiteralUtil::Slice(*input_3x4, {0, 2}, {2, 4}); - auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *result)); + auto input_3x4 = + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto result = input_3x4->Slice({0, 2}, {2, 4}); + auto expected = Literal::CreateR2({{3, 4}, {7, 8}}); + EXPECT_TRUE(expected->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR3U32Full) { - auto 
input_2x3x2 = LiteralUtil::CreateR3( + auto input_2x3x2 = Literal::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - auto result = LiteralUtil::Slice(*input_2x3x2, {0, 0, 0}, {2, 3, 2}); - EXPECT_TRUE(LiteralUtil::Equal(*input_2x3x2, *result)); + auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); + EXPECT_TRUE(input_2x3x2->Equal(*result)); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output; - LiteralUtil::PopulateR1({77}, &output); - auto expected = LiteralUtil::CreateR1({77}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateR1({77}); + auto expected = Literal::CreateR1({77}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateR2U64) { Literal output; - LiteralUtil::PopulateR1({{77, 88}}, &output); - auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateR1({{77, 88}}); + auto expected = Literal::CreateR1({{77, 88}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; - LiteralUtil::PopulateWithValue(2.5f, {}, &output); - auto expected = LiteralUtil::CreateR0(2.5f); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(2.5f, {}); + auto expected = Literal::CreateR0(2.5f); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output; - LiteralUtil::PopulateWithValue(-7, {3}, &output); - auto expected = LiteralUtil::CreateR1({-7, -7, -7}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(-7, {3}); + auto expected = Literal::CreateR1({-7, -7, -7}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output; - LiteralUtil::PopulateWithValue(42, {2, 2}, &output); - auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(42, {2, 2}); + 
auto expected = Literal::CreateR2({{42, 42}, {42, 42}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output; half h(0.25f); - LiteralUtil::PopulateWithValue(h, {}, &output); - auto expected = LiteralUtil::CreateR0(h); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {}); + auto expected = Literal::CreateR0(h); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { Literal output; half h(0.5f); - LiteralUtil::PopulateWithValue(h, {3}, &output); - auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {3}); + auto expected = Literal::CreateR1({h, h, h}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { Literal output; half h(2.0f); - LiteralUtil::PopulateWithValue(h, {2, 2}, &output); - auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {2, 2}); + auto expected = Literal::CreateR2({{h, h}, {h, h}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, ReplicateR2U32) { - auto input = LiteralUtil::CreateR2( - {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto output = LiteralUtil::Replicate(*input, 3); - auto expected = LiteralUtil::CreateR3( + auto input = + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto output = input->Replicate(3); + auto expected = Literal::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_TRUE(LiteralUtil::Equal(*output, *expected)); + EXPECT_TRUE(output->Equal(*expected)); } TEST_F(LiteralUtilTest, Copy) { @@ -712,13 +698,13 @@ TEST_F(LiteralUtilTest, Copy) { for (const auto& layout : layouts) { Shape shape = ShapeUtil::MakeShapeWithLayout( 
primitive_util::NativeToPrimitiveType(), dimensions, layout); - auto blank = LiteralUtil::CreateFromShape(shape); - auto source = LiteralUtil::CreateFromShape(shape); + auto blank = Literal::CreateFromShape(shape); + auto source = Literal::CreateFromShape(shape); const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; auto init_proc = [&](const std::vector& indexes) { - LiteralUtil::Set(source.get(), indexes, ++seqnr); + source->Set(indexes, ++seqnr); return true; }; @@ -729,8 +715,7 @@ TEST_F(LiteralUtilTest, Copy) { const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(LiteralUtil::Copy(*source, src_base, blank.get(), dest_base, - copy_size)); + TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; @@ -741,9 +726,8 @@ TEST_F(LiteralUtilTest, Copy) { std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, blank_indexes.begin(), std::plus()); - auto bval = LiteralUtil::Get(*blank, blank_indexes); - matched = (bval != 0 && - bval == LiteralUtil::Get(*source, source_indexes)); + auto bval = blank->Get(blank_indexes); + matched = (bval != 0 && bval == source->Get(source_indexes)); return matched; }; ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, @@ -753,25 +737,25 @@ TEST_F(LiteralUtilTest, Copy) { } TEST_F(LiteralUtilTest, CopyScalars) { - auto zero = LiteralUtil::CreateR0(0); - auto nine = LiteralUtil::CreateR0(9); - TF_EXPECT_OK(LiteralUtil::Copy(*nine, {}, zero.get(), {}, {})); - EXPECT_TRUE(LiteralUtil::Equal(*zero, *nine)); - - auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(LiteralUtil::Copy(*vect, {5}, zero.get(), {}, {})); - EXPECT_EQ(LiteralUtil::Get(*zero, {}), 17); - TF_EXPECT_OK(LiteralUtil::Copy(*zero, 
{}, vect.get(), {4}, {})); - EXPECT_EQ(LiteralUtil::Get(*vect, {4}), 17); + auto zero = Literal::CreateR0(0); + auto nine = Literal::CreateR0(9); + TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {})); + EXPECT_TRUE(zero->Equal(*nine)); + + auto vect = Literal::CreateR1({3, 4, 9, 12, 5, 17, 21}); + TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {})); + EXPECT_EQ(zero->Get({}), 17); + TF_EXPECT_OK(vect->Copy(*zero, {}, {4}, {})); + EXPECT_EQ(vect->Get({4}), 17); } TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent - auto m1 = LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); Literal* l1 = m1.get(); - const char* d1 = static_cast(LiteralUtil::InternalData(*l1)); + const char* d1 = static_cast(l1->InternalData()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -780,14 +764,13 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d1[5], 0); EXPECT_EQ(d1[6], 0); EXPECT_EQ(d1[7], 0); - EXPECT_EQ(LiteralUtil::InternalData(*l1), - LiteralUtil::MutableInternalData(l1)); + EXPECT_EQ(l1->InternalData(), l1->MutableInternalData()); half h1(1.0f); half h2(2.0f); - auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); + auto m2 = Literal::CreateR2({{h1, h2}, {h2, h1}}); Literal* l2 = m2.get(); - const char* d2 = static_cast(LiteralUtil::InternalData(*l2)); + const char* d2 = static_cast(l2->InternalData()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -796,8 +779,7 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d2[5], 0x40); EXPECT_EQ(d2[6], 0); EXPECT_EQ(d2[7], 0x3C); - EXPECT_EQ(LiteralUtil::InternalData(*l2), - LiteralUtil::MutableInternalData(l2)); + EXPECT_EQ(l2->InternalData(), l2->MutableInternalData()); } TEST_F(LiteralUtilTest, Populate) { @@ -818,19 +800,19 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = 
ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = LiteralUtil::CreateFromShape(shape); + auto literal = Literal::CreateFromShape(shape); auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return LiteralUtil::LinearIndex(*literal, indexes) + 17; + return literal->LinearIndex(indexes) + 17; }; - TF_EXPECT_OK(LiteralUtil::Populate(literal.get(), generator)); + TF_EXPECT_OK(literal->Populate(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](const std::vector& indexes) { - auto value = LiteralUtil::Get(*literal, indexes); + auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; @@ -842,65 +824,66 @@ TEST_F(LiteralUtilTest, Populate) { TEST_F(LiteralUtilTest, ConvertR4) { // clang-format off - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); - auto expected = LiteralUtil::CreateR4WithLayout({{ + auto expected = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - auto converted = LiteralUtil::Convert(*original); + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr converted, + original->Convert(U32)); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted)); + EXPECT_TRUE(expected->Equal(*converted)); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format off - auto s8 = LiteralUtil::CreateR4WithLayout({{ + auto s8 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, 
{{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s32 = LiteralUtil::CreateR4WithLayout({{ + auto s32 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u32 = LiteralUtil::CreateR4WithLayout({{ + auto u32 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s64 = LiteralUtil::CreateR4WithLayout({{ + auto s64 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u64 = LiteralUtil::CreateR4WithLayout({{ + auto u64 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto pred = LiteralUtil::CreateR4WithLayout({{ + auto pred = Literal::CreateR4WithLayout({{ {{true, false, true, false}, {false, true, false, true}}, {{false, true, false, true}, {true, false, true, false}}, {{true, false, true, false}, {false, true, false, true}}, }}, layout_r4_dim0major_); - auto int32_pred = LiteralUtil::CreateR4WithLayout({{ + auto int32_pred = Literal::CreateR4WithLayout({{ {{1, 0, 1, 0}, {0, 1, 0, 1}}, {{0, 1, 0, 1}, {1, 0, 1, 0}}, {{1, 0, 1, 0}, {0, 1, 0, 1}}, }}, layout_r4_dim0major_); - auto f32 = LiteralUtil::CreateR4WithLayout({{ + auto f32 = Literal::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - auto f64 = LiteralUtil::CreateR4WithLayout({{ + auto f64 = Literal::CreateR4WithLayout({{ {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, 
@@ -908,40 +891,40 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format on std::unique_ptr conv; - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, U32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *u32)); + conv = s8->Convert(U32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*u32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = s8->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, U64).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *u64)); + conv = s8->Convert(U64).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*u64)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, S64).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s64)); + conv = s8->Convert(S64).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s64)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, PRED).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *pred)); + conv = s8->Convert(PRED).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*pred)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*pred, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *int32_pred)); + conv = pred->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*int32_pred)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*f32, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = f32->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*f64, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = f64->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s32, F32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *f32)); + conv = s32->Convert(F32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*f32)); - 
EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, TUPLE).status().code(), + EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, F16).status().code(), + EXPECT_EQ(s32->Convert(F16).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, S16).status().code(), + EXPECT_EQ(s32->Convert(S16).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, U16).status().code(), + EXPECT_EQ(s32->Convert(U16).status().code(), tensorflow::error::INVALID_ARGUMENT); } @@ -996,9 +979,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { half h1(1.0f); half h2(2.0f); - const char half_vals[8] = { - 0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C - }; + const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C}; LiteralProto p; p.mutable_shape()->set_element_type(F16); p.mutable_shape()->clear_dimensions(); @@ -1006,7 +987,6 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { p.clear_f16s(); p.set_f16s(half_vals, 8); - Literal literal(p); ASSERT_EQ(4, literal.f16s_size()); ASSERT_EQ(h1, literal.f16s(0)); @@ -1022,6 +1002,5 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index d488830a6cd7b07ccb8de237121ab0693bd73a0f..70e0f5a74711c8ceef1b6d4225141aa1cc9c6219 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -58,8 +58,7 @@ StatusOr> PackedLiteralReader::Read( } int64 elements = ShapeUtil::ElementsIn(shape); - LiteralUtil::Resize(elements, std::numeric_limits::quiet_NaN(), - result.get()); + result->Resize(elements, std::numeric_limits::quiet_NaN()); std::vector* field = result->mutable_f32s(); char* data = tensorflow::bit_cast(field->data()); uint64 
bytes = elements * sizeof(float); diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index e8de559a5ef9e69864abab21c99887d40cfd378a..138e360c290b97cd19d8a4564a1a8668d7052627 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -135,6 +135,49 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); } +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + std::vector dim_lengths{static_cast(operand.size())}; + auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + + std::vector window_counts(window.size(), 0); + std::vector pad_low(window.size(), 0); + for (int64 i = 0; i < window.size(); ++i) { + window_counts[i] = + WindowCount(dim_lengths[i], window[i], stride[i], padding); + pad_low[i] = padding_both[i].first; + } + auto result = MakeUnique>(window_counts[0]); + + // Do a full 1D reduce window. 
+ for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { + int64 i0_base = i0 * stride[0] - pad_low[0]; + + float val = init; + for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { + if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) { + val = reduce_func(val, operand[i0_base + i0_win]); + } + } + (*result)[i0] = val; + } + return result; +} + +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DAdd( + const tensorflow::gtl::ArraySlice& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; + return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, + padding); +} + /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( const Array2D& operand, float init, const tensorflow::gtl::ArraySlice& window, @@ -252,6 +295,20 @@ ReferenceUtil::ReduceWindow4DGeneric( padding); } +/* static */ std::unique_ptr> ReferenceUtil::BatchNorm4D( + const Array4D& input, const Array4D& mean, + const Array4D& var, const Array4D& scale, + const Array4D& offset, float epsilon) { + auto normalized = + *MapArray4D(input, mean, [](float a, float b) { return a - b; }); + normalized = *MapArray4D(normalized, var, [&](float a, float b) { + return a / std::sqrt(b + epsilon); + }); + normalized = + *MapArray4D(normalized, scale, [](float a, float b) { return a * b; }); + return MapArray4D(normalized, offset, [](float a, float b) { return a + b; }); +} + /* static */ std::unique_ptr> ReferenceUtil::SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, @@ -439,21 +496,21 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( // Lambda to access the rhs operand at the given 4D index. height_over_dky // should be equal to height / dky, and width_over_dkx should be equal to // width / dkx. (This is an optimization to avoid doing divisions.) 
- const auto rhs_element = [&]( - int64 kernel_output_feature, int64 kernel_input_feature, int64 height, - int64 width, int64 height_over_dky, int64 width_over_dkx) { - DCHECK_EQ(height % dky, 0); - DCHECK_EQ(width % dkx, 0); - DCHECK_EQ(height / dky, height_over_dky); - DCHECK_EQ(width / dkx, width_over_dkx); - - std::array index; - index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; - index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; - index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; - index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; - return rhs(index[0], index[1], index[2], index[3]); - }; + const auto rhs_element = + [&](int64 kernel_output_feature, int64 kernel_input_feature, int64 height, + int64 width, int64 height_over_dky, int64 width_over_dkx) { + DCHECK_EQ(height % dky, 0); + DCHECK_EQ(width % dkx, 0); + DCHECK_EQ(height / dky, height_over_dky); + DCHECK_EQ(width / dkx, width_over_dkx); + + std::array index; + index[dnums.kernel_output_feature_dimension()] = kernel_output_feature; + index[dnums.kernel_input_feature_dimension()] = kernel_input_feature; + index[dnums.kernel_spatial_dimensions(0)] = height_over_dky; + index[dnums.kernel_spatial_dimensions(1)] = width_over_dkx; + return rhs(index[0], index[1], index[2], index[3]); + }; // Lambda to access the result data at the given 4D index. 
const auto result_element = [&](int64 batch, int64 kernel_output_feature, @@ -491,6 +548,30 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } } } + if (samples == 0 || kx == 0 || ky == 0 || ox == 0 || oy == 0 || oz == 0 || + iz == 0) { + LOG(INFO) << "Output will be trivially empty because one of these " + "dimensions is 0: samples: " + << samples << " kx: " << kx << " ky: " << ky << " ox: " << ox + << " oy: " << oy << " oz: " << oz << " iz: " << iz; + return result; + } + bool trivial = true; + auto check_trivial = [&trivial](tensorflow::gtl::ArraySlice indices, + float value) { + if (value != 0.0) { + trivial = false; + } + }; + lhs.Each(check_trivial); + if (trivial) { + LOG(FATAL) << "LHS is all 0.0."; + } + trivial = true; + rhs.Each(check_trivial); + if (trivial) { + LOG(FATAL) << "RHS is all 0.0."; + } return result; } @@ -566,6 +647,38 @@ ReferenceUtil::ReduceToRowArray2D( return result; } +/* static */ std::unique_ptr> ReferenceUtil::Broadcast1DTo4D( + const std::vector& array, const std::vector& bounds, + int64 broadcast_from_dim) { + auto result = + MakeUnique>(bounds[0], bounds[1], bounds[2], bounds[3]); + for (int64 i = 0; i < result->n1(); ++i) { + for (int64 j = 0; j < result->n2(); ++j) { + for (int64 k = 0; k < result->n3(); ++k) { + for (int64 l = 0; l < result->n4(); ++l) { + switch (broadcast_from_dim) { + case 0: + (*result)(i, j, k, l) = array[i]; + break; + case 1: + (*result)(i, j, k, l) = array[j]; + break; + case 2: + (*result)(i, j, k, l) = array[k]; + break; + case 3: + (*result)(i, j, k, l) = array[l]; + break; + default: + break; + } + } + } + } + } + return result; +} + /* static */ std::unique_ptr> ReferenceUtil::Reduce3DTo2D( const Array3D& array, float init, tensorflow::gtl::ArraySlice dims, diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index f58f0bdc9f51dff62c10dda4aba7aac03e689ce7..1d326aff5fb5cc752b9d1e9f1f735a26cbb83b3a 100644 --- 
a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -120,6 +120,11 @@ class ReferenceUtil { tensorflow::gtl::ArraySlice dims, std::function reduce_function); + // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`. + static std::unique_ptr> Broadcast1DTo4D( + const std::vector& array, const std::vector& bounds, + int64 broadcast_from_dim); + // Returns the result of reducing the 3D array to a 2D array, reducing away // the dimensions specified in dims. static std::unique_ptr> Reduce3DTo2D( @@ -144,19 +149,26 @@ class ReferenceUtil { static int64 WindowCount(int64 unpadded_width, int64 window_len, int64 stride, Padding padding); - // Performs a 2D window reduction with Add as the function to apply. + // Windowed reductions with Add as the function to apply. + static std::unique_ptr> ReduceWindow1DAdd( + const tensorflow::gtl::ArraySlice& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); static std::unique_ptr> ReduceWindow2DAdd( const Array2D& operand, float init, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); - - // Performs a 4D window reduction with Add as the function to apply. static std::unique_ptr> ReduceWindow4DAdd( const Array4D& operand, float init, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); - // Performs a 4D window reduction with a generic reduce function. + // Windowed reductions with a generic reduce function. 
+ static std::unique_ptr> ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, @@ -169,6 +181,12 @@ class ReferenceUtil { const tensorflow::gtl::ArraySlice& stride, const tensorflow::gtl::ArraySlice>& padding); + // Batch normalize data. + static std::unique_ptr> BatchNorm4D( + const Array4D& input, const Array4D& mean, + const Array4D& var, const Array4D& scale, + const Array4D& offset, float epsilon); + // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. static std::unique_ptr> SelectAndScatter4DGePlus( @@ -396,6 +414,41 @@ class ReferenceUtil { return result; } + // Applies map_function to each pair of elements in the input lhs and rhs + // (4D array) and returns the result. + template + static std::unique_ptr> MapArray4D(const Array4D& lhs, + const Array4D& rhs, + F&& map_function) { + return MapWithIndexArray4D( + lhs, rhs, [&](float lhs, float rhs, int64, int64, int64, int64) { + return map_function(lhs, rhs); + }); + } + + // Applies map_function to each pair of element in lhs and rhs (4D array) and + // returns the result. + // (plane, depth, height, width) index of each element is also provided as + // arguments to map_function. 
+ template + static std::unique_ptr> MapWithIndexArray4D( + const Array4D& lhs, const Array4D& rhs, F&& map_function) { + auto result = MakeUnique>(lhs.planes(), lhs.depth(), + lhs.height(), lhs.width()); + for (int64 plane = 0; plane < lhs.planes(); ++plane) { + for (int64 depth = 0; depth < lhs.depth(); ++depth) { + for (int64 height = 0; height < lhs.height(); ++height) { + for (int64 width = 0; width < lhs.width(); ++width) { + (*result)(plane, depth, height, width) = map_function( + lhs(plane, depth, height, width), + rhs(plane, depth, height, width), plane, depth, height, width); + } + } + } + } + return result; + } + // Returns the result of a 2D pad on an input matrix. static std::unique_ptr> PadArray2D( const Array2D& operand, const PaddingConfig& padding, diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index f839ac019df07c5c5e07eed856ea55463bb3efae..215f220258964afc4f3a3acc25b98882c9483d2e 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -52,7 +52,7 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -62,7 +62,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -70,7 +70,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { TEST_F(ReferenceUtilTest, ReduceToColArray2D) { 
auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto actual_literal = LiteralUtil::CreateR1(*result); + auto actual_literal = Literal::CreateR1(*result); LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -78,7 +78,7 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) { TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto actual_literal = LiteralUtil::CreateR1(*result); + auto actual_literal = Literal::CreateR1(*result); LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -86,7 +86,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -107,7 +107,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = Literal::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, 
/*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); @@ -124,7 +124,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = Literal::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); @@ -161,7 +161,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { })); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -195,7 +195,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { })); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -247,7 +247,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { }}); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -296,7 +296,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { Array4D expected({{{{2514, 2685}}}}); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -309,7 +309,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { auto actual = ReferenceUtil::ApplyElementwise2D( [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); - auto 
actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); + auto actual_literal = Literal::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, *actual_literal, ErrorSpec(0.0001)); } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0687368b83db343cfa15da969b9f4d9d1a821078..f3b2bf627940b73a14f7e07c73932c1c9b467cc6 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -24,9 +24,7 @@ xla_proto_library( xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], - deps = [ - "//tensorflow/compiler/xla:xla_data_proto", - ], + deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) # Filegroup used to collect source files for dependency checking. @@ -88,11 +86,13 @@ cc_library( deps = [ ":hlo", ":hlo_query", + ":shape_inference", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], @@ -106,12 +106,16 @@ cc_test( ":hlo", ":hlo_evaluator", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test_main", @@ -138,6 +142,7 @@ cc_library( deps = [ ":hlo_module_config", ":hlo_proto", + ":hlo_reachability", ":name_uniquer", ":versioned_computation_handle", "//tensorflow/compiler/xla:literal_util", 
@@ -155,6 +160,31 @@ cc_library( ], ) +cc_library( + name = "hlo_reachability", + srcs = ["hlo_reachability.cc"], + hdrs = ["hlo_reachability.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_test( + name = "hlo_reachability_test", + srcs = ["hlo_reachability_test.cc"], + deps = [ + ":hlo", + ":hlo_reachability", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "hlo_matchers", testonly = 1, @@ -285,7 +315,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", ], ) @@ -303,7 +333,7 @@ cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", ], @@ -330,6 +360,7 @@ cc_library( hdrs = ["backend.h"], deps = [ ":compiler", + ":computation_placer", ":device_memory_allocator", ":platform_util", ":pool", @@ -338,7 +369,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:backend_flags", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -382,6 +412,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/legacy_flags:backend_flags", 
"//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:lib", @@ -416,6 +447,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -508,7 +540,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", ], @@ -707,9 +738,10 @@ cc_library( ], deps = [ ":buffer_liveness", + ":heap_simulator", ":hlo", - ":hlo_ordering", ":hlo_proto", + ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -718,7 +750,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -736,6 +767,7 @@ cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", + ":hlo_scheduling", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -748,11 +780,61 @@ cc_test( ], ) +cc_library( + name = "hlo_ordering", + srcs = ["hlo_ordering.cc"], + hdrs = ["hlo_ordering.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_ordering_test", + size = "small", + srcs = ["hlo_ordering_test.cc"], + deps = [ 
+ ":hlo", + ":hlo_ordering", + ":hlo_scheduling", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + +cc_library( + name = "heap_simulator", + srcs = ["heap_simulator.cc"], + hdrs = ["heap_simulator.h"], + deps = [ + ":hlo", + ":hlo_ordering", + ":hlo_proto", + ":liveness_util", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_test( name = "heap_simulator_test", size = "small", srcs = ["heap_simulator_test.cc"], deps = [ + ":heap_simulator", ":hlo", ":hlo_ordering", ":logical_buffer", @@ -765,23 +847,15 @@ cc_test( ], ) -# The hlo_ordering library contains both hlo_ordering and heap_simulator because -# they are mutually dependent. cc_library( - name = "hlo_ordering", - srcs = [ - "heap_simulator.cc", - "hlo_ordering.cc", - ], - hdrs = [ - "heap_simulator.h", - "hlo_ordering.h", - ], + name = "hlo_scheduling", + srcs = ["hlo_scheduling.cc"], + hdrs = ["hlo_scheduling.h"], deps = [ - ":call_graph", + ":heap_simulator", ":hlo", + ":hlo_ordering", ":hlo_proto", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -794,12 +868,13 @@ cc_library( ) cc_test( - name = "hlo_ordering_test", + name = "hlo_scheduling_test", size = "small", - srcs = ["hlo_ordering_test.cc"], + srcs = ["hlo_scheduling_test.cc"], deps = [ ":hlo", ":hlo_ordering", + ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -948,6 +1023,26 @@ cc_test( ], ) +cc_library( + name = "computation_placer", + srcs = ["computation_placer.cc"], + hdrs = ["computation_placer.h"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + 
"//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform computation placer registration +) + cc_library( name = "generic_transfer_manager", srcs = ["generic_transfer_manager.cc"], @@ -1030,12 +1125,8 @@ cc_test( cc_library( name = "hlo_cost_analysis", - srcs = [ - "hlo_cost_analysis.cc", - ], - hdrs = [ - "hlo_cost_analysis.h", - ], + srcs = ["hlo_cost_analysis.cc"], + hdrs = ["hlo_cost_analysis.h"], deps = [ ":hlo", "//tensorflow/compiler/xla:shape_util", @@ -1137,12 +1228,8 @@ cc_test( cc_library( name = "logical_buffer", - srcs = [ - "logical_buffer.cc", - ], - hdrs = [ - "logical_buffer.h", - ], + srcs = ["logical_buffer.cc"], + hdrs = ["logical_buffer.h"], deps = [ ":hlo", ":hlo_proto", @@ -1155,18 +1242,31 @@ cc_library( ) cc_library( - name = "hlo_dataflow_analysis", - srcs = [ - "hlo_dataflow_analysis.cc", - ], - hdrs = [ - "hlo_dataflow_analysis.h", + name = "hlo_value", + srcs = ["hlo_value.cc"], + hdrs = ["hlo_value.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", ], +) + +cc_library( + name = "hlo_dataflow_analysis", + srcs = ["hlo_dataflow_analysis.cc"], + hdrs = ["hlo_dataflow_analysis.h"], deps = [ ":call_graph", ":hlo", + ":hlo_ordering", + ":hlo_value", ":liveness_util", - "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", @@ -1174,7 +1274,6 @@ cc_library( "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", ], ) @@ -1201,20 +1300,32 @@ cc_test( ) cc_library( - name = "hlo_alias_analysis", - srcs = [ - "hlo_alias_analysis.cc", - ], - hdrs = [ - "hlo_alias_analysis.h", + name = "hlo_buffer", + srcs = ["hlo_buffer.cc"], + hdrs = ["hlo_buffer.h"], + deps = [ + ":hlo", + ":hlo_value", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", ], +) + +cc_library( + name = "hlo_alias_analysis", + srcs = ["hlo_alias_analysis.cc"], + hdrs = ["hlo_alias_analysis.h"], deps = [ - ":call_graph", ":hlo", + ":hlo_buffer", ":hlo_dataflow_analysis", - ":logical_buffer", - "//tensorflow/compiler/xla:shape_tree", + ":hlo_value", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -1245,12 +1356,8 @@ cc_test( cc_library( name = "tuple_points_to_analysis", - srcs = [ - "tuple_points_to_analysis.cc", - ], - hdrs = [ - "tuple_points_to_analysis.h", - ], + srcs = ["tuple_points_to_analysis.cc"], + hdrs = ["tuple_points_to_analysis.h"], deps = [ ":hlo", ":logical_buffer", @@ -1287,12 +1394,8 @@ cc_test( cc_library( name = "compilation_cache", - srcs = [ - "compilation_cache.cc", - ], - hdrs = [ - "compilation_cache.h", - ], + srcs = ["compilation_cache.cc"], + hdrs = ["compilation_cache.h"], deps = [ ":executable", ":hlo_module_config", @@ -1386,7 +1489,10 @@ cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], hdrs = ["hlo_verifier.h"], - deps = [":hlo_pass"], + deps = [ + ":hlo_pass", + "//tensorflow/core:lib", + ], ) cc_library( @@ -1398,9 +1504,9 @@ cc_library( ":call_graph", ":flatten_call_graph", ":hlo", - ":hlo_cost_analysis", ":hlo_dce", ":hlo_ordering", + 
":hlo_scheduling", ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", @@ -1497,8 +1603,8 @@ cc_library( "hlo_pass_pipeline.h", ], deps = [ - ":compiler", ":hlo", + ":hlo_graph_dumper", ":hlo_pass", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1572,10 +1678,8 @@ cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:lib", ], ) @@ -1707,8 +1811,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", ], alwayslink = 1, ) @@ -1777,10 +1882,41 @@ cc_library( ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:status", + ], +) + +cc_library( + name = "reduce_precision_insertion", + srcs = ["reduce_precision_insertion.cc"], + hdrs = ["reduce_precision_insertion.h"], + deps = [ + ":buffer_liveness", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], ) +cc_test( + name = "reduce_precision_insertion_test", + size = "small", + srcs = ["reduce_precision_insertion_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":reduce_precision_insertion", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git 
a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 754ac0c68dc025c6d2bde4b40e148e6043f0cf6d..b351861425d76b9dbed6e71daf935191e98b40ba 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -48,7 +48,7 @@ namespace { // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && - LiteralUtil::IsAll(operand->literal(), value); + operand->literal().IsAll(value); } bool IsAll(const HloInstruction* op, int8 value) { @@ -126,10 +126,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override; - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; + + Status HandleConvert(HloInstruction* convert) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; @@ -179,11 +181,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs, HloInstruction* rhs) override; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) override; - - Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleMaximum(HloInstruction* maximum) override; + Status HandleMinimum(HloInstruction* minimum) override; // Returns whether algebraic simplification has occurred. 
const bool changed() const { return changed_; } @@ -334,16 +333,16 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. - if (operand->opcode() == HloOpcode::kCopy) { + if (copy->operand(0)->opcode() == HloOpcode::kCopy) { return ReplaceWithNewInstruction( - copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, - operand->operands()[0])); + copy, HloInstruction::CreateUnary( + copy->shape(), HloOpcode::kCopy, + copy->mutable_operand(0)->mutable_operand(0))); } // All copies can be eliminated (assuming layout constraints are satisified). - ReplaceInstructionIfSameShape(copy, operand); + ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); return Status::OK(); } @@ -415,6 +414,32 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return Status::OK(); } +static HloInstruction* BuildTupleConstant(HloComputation* computation, + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector elems; + elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); + for (const Literal& child : literal.tuple_literals()) { + elems.push_back(BuildTupleConstant(computation, child)); + } + return computation->AddInstruction(HloInstruction::CreateTuple(elems)); + } else { + return computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique(literal))); + } +} + +Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant, + const Literal& literal) { + // Tuple constants aren't directly supported by any backend. Expand them into + // explicit Tuple instructions. 
+ if (ShapeUtil::IsTuple(constant->shape())) { + return ReplaceInstruction(constant, + BuildTupleConstant(computation_, literal)); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub, HloInstruction* lhs, HloInstruction* rhs) { @@ -448,6 +473,72 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, subtract)); } + // A/exp(B) => A*exp(-B) + if (rhs->opcode() == HloOpcode::kExp) { + VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); + HloInstruction* negate = + computation_->AddInstruction(HloInstruction::CreateUnary( + divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0))); + HloInstruction* new_exp = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs, new_exp)); + } + + // A/pow(B,C) => A*pow(B,-C) + if (rhs->opcode() == HloOpcode::kPower) { + VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); + HloInstruction* negate = + computation_->AddInstruction(HloInstruction::CreateUnary( + divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(1))); + HloInstruction* new_power = computation_->AddInstruction( + HloInstruction::CreateBinary(divide->shape(), HloOpcode::kPower, + rhs->mutable_operand(0), negate)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs, new_power)); + } + + // Simplifying integral division would produce unexpected results. 
+ if (ShapeUtil::ElementIsIntegral(divide->shape())) { + return Status::OK(); + } + + // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) + if (lhs->opcode() == HloOpcode::kDivide && + rhs->opcode() == HloOpcode::kDivide) { + auto a_times_d = computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(0), + rhs->mutable_operand(1))); + auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), + rhs->mutable_operand(0))); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kDivide, a_times_d, b_times_c)); + } + + // (A / B) / C => A / (B * C) + if (lhs->opcode() == HloOpcode::kDivide) { + auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + return ReplaceWithNewInstruction( + divide, + HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, + lhs->mutable_operand(0), b_times_c)); + } + + // A / (B / C) => (A*C) / B + if (rhs->opcode() == HloOpcode::kDivide) { + auto a_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); + return ReplaceWithNewInstruction( + divide, + HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, + a_times_c, rhs->mutable_operand(0))); + } + return Status::OK(); } @@ -469,7 +560,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, ShapeUtil::HasZeroElements(lhs->shape()) || ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -507,7 +598,7 @@ Status 
AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, {0}, add_reduce_computation)); @@ -531,7 +622,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* reduce; if (ShapeUtil::Rank(rhs->shape()) == 1) { auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -571,7 +662,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(dot->shape().element_type(), {lhs->shape().dimensions(0)}), @@ -595,6 +686,16 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { return Status::OK(); } + + // exp(A) * exp(B) => exp(A+B) + if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + auto add = computation_->AddInstruction(HloInstruction::CreateBinary( + multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), + 
rhs->mutable_operand(0))); + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); + } return Status::OK(); } @@ -606,6 +707,17 @@ Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log, ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { return Status::OK(); } + + // ln(pow(A,B)) => B*ln(A) + if (operand->opcode() == HloOpcode::kPower) { + auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary( + log->shape(), HloOpcode::kLog, operand->mutable_operand(0))); + return ReplaceWithNewInstruction( + log, + HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, + new_log, operand->mutable_operand(1))); + } + return Status::OK(); } @@ -792,12 +904,11 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A conversion to the same element type as the operand is a nop and can be // removed. A conversion of a constant can be simplified by making a new // constant. -Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - PrimitiveType src_type = operand->shape().element_type(); +Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { + PrimitiveType src_type = convert->operand(0)->shape().element_type(); PrimitiveType dest_type = convert->shape().element_type(); if (src_type == dest_type) { - return ReplaceInstruction(convert, operand); + return ReplaceInstruction(convert, convert->mutable_operand(0)); } return Status::OK(); } @@ -878,10 +989,10 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { } // Verify that the slice shape matches the pad shape. 
- TF_ASSIGN_OR_RETURN(Shape inferred_slice_shape, - ShapeInference::InferSliceShape( - nonzero_pad_shape, start_indices, end_indices, - strides)); + TF_ASSIGN_OR_RETURN( + Shape inferred_slice_shape, + ShapeInference::InferSliceShape(nonzero_pad_shape, start_indices, + end_indices, strides)); TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape())); std::unique_ptr slice = HloInstruction::CreateSlice( @@ -897,8 +1008,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, HloInstruction* rhs) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); if (IsAll(rhs, 0)) { - auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( - LiteralUtil::One(power->shape().element_type()))); + auto one = HloInstruction::CreateConstant( + Literal::One(power->shape().element_type()).CloneToUnique()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -914,6 +1025,14 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, return Status::OK(); } + // pow(exp(A),B) => exp(A*B) + if (lhs->opcode() == HloOpcode::kExp) { + auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( + power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs)); + return ReplaceWithNewInstruction( + power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, + a_times_b)); + } VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); if (IsAll(rhs, 2)) { return ReplaceWithNewInstruction( @@ -923,9 +1042,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { - auto* one = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( - LiteralUtil::One(rhs->shape().element_type())))); + auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( + 
Literal::One(rhs->shape().element_type()).CloneToUnique())); return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, one, lhs)); @@ -937,6 +1055,9 @@ StatusOr AlgebraicSimplifierVisitor:: TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* reshape_or_broadcast) { bool changed = false; + if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + return false; + } HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); for (HloInstruction* user : reshape_or_broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { @@ -1008,7 +1129,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // dimension. if (ShapeUtil::HasZeroElements(reshape->shape())) { auto empty_constant = HloInstruction::CreateConstant( - LiteralUtil::CreateFromShape(reshape->shape())); + Literal::CreateFromShape(reshape->shape())); return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); } @@ -1208,8 +1329,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // try to get more fancy about proving equivalence in cases beyond that. 
if (pad_value->opcode() != HloOpcode::kConstant || reduce_init_value->opcode() != HloOpcode::kConstant || - !LiteralUtil::Equal(pad_value->literal(), - reduce_init_value->literal())) { + !pad_value->literal().Equal(reduce_init_value->literal())) { VLOG(10) << "Not folding pad into reduce-window due to different pad " "values."; return Status::OK(); @@ -1396,9 +1516,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } -Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { // Match the following tree: // min_operand operand // \ / @@ -1429,9 +1547,7 @@ Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { // Match the following tree: // max_operand operand // \ / diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e4368a7bb25093f70bf78288db2021d36fa7f25a..ff119c009862d75e5a248bb82eaa822351898d7c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -55,7 +55,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); @@ -76,7 +76,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { HloInstruction* param0 = 
builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); builder.AddInstruction( @@ -99,7 +99,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); + HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))); HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); builder.AddInstruction( @@ -123,7 +123,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); @@ -138,6 +138,155 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { EXPECT_EQ(root, param0); } +// Test that (A/B)/C is simplified to A/(B*C). 
+TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* div = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(op::Divide(param0, param1), param2)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(param0, op::Multiply(param1, param2))); +} + +// Test that A/(B/C) is simplified to (A*C)/B. 
+TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* div = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param1, param2)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(param0, op::Divide(param1, param2))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(op::Multiply(param0, param2), param1)); +} + +// Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). 
+TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* param3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, r0f32, "param3")); + HloInstruction* div0 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1)); + HloInstruction* div1 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param2, param3)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div0, div1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT( + computation->root_instruction(), + op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT( + computation->root_instruction(), + op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); +} + +// Test that A/exp(B) is simplified to A*exp(-B). 
+TEST_F(AlgebraicSimplifierTest, DivOfExp) {  // A/exp(B) must become A*exp(-B).
+  Shape r0f32 = ShapeUtil::MakeShape(F32, {});  // F32 shape with an empty dimension list.
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0f32, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r0f32, "param1"));
+  HloInstruction* exp = builder.AddInstruction(  // Divisor: exp(param1).
+      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
+  builder.AddInstruction(  // param0 / exp; checked as the root below.
+      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(computation->root_instruction(),  // Sanity check of the graph before the pass.
+              op::Divide(param0, op::Exp(param1)));
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());  // Result must be true (presumably "changed"; confirm).
+
+  EXPECT_THAT(computation->root_instruction(),  // Divide replaced by multiply with a negated exponent.
+              op::Multiply(param0, op::Exp(op::Negate(param1))));
+}
+
+// Test that A/pow(B,C) is simplified to A*pow(B,-C).
+TEST_F(AlgebraicSimplifierTest, DivOfPower) {  // A/pow(B,C) must become A*pow(B,-C).
+  Shape r0f32 = ShapeUtil::MakeShape(F32, {});  // F32 shape with an empty dimension list.
+  HloComputation::Builder builder(TestName());
+  HloInstruction* param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r0f32, "param0"));
+  HloInstruction* param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, r0f32, "param1"));
+  HloInstruction* param2 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, r0f32, "param2"));
+  HloInstruction* power = builder.AddInstruction(  // Divisor: pow(param1, param2).
+      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param1, param2));
+  builder.AddInstruction(  // param0 / power; checked as the root below.
+      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power));
+
+  auto module = CreateNewModule();
+  auto computation = module->AddEntryComputation(builder.Build());
+
+  EXPECT_THAT(computation->root_instruction(),  // Sanity check of the graph before the pass.
+              op::Divide(param0, op::Power(param1, param2)));
+
+  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+                                 non_bitcasting_callback());
+  ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());  // Result must be true (presumably "changed"; confirm).
+
+  EXPECT_THAT(computation->root_instruction(),  // Divide replaced by multiply; only the exponent is negated.
+              op::Multiply(param0, op::Power(param1, op::Negate(param2))));
+}
+
 // Test that A/1 is simplified to A for a scalar.
TEST_F(AlgebraicSimplifierTest, DivOneScalar) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -145,7 +294,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); @@ -167,7 +316,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); + Literal::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); @@ -239,6 +388,89 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { op::Exp(op::Subtract(param0, param1))); } +// Test that exp(A)*exp(B) is simplified to exp(A+B) +TEST_F(AlgebraicSimplifierTest, ExpMul) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + 
EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Exp(param0), op::Exp(param1))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Exp(op::Add(param0, param1))); +} + +// Test that pow(exp(A), B) is simplified to exp(A*B) +TEST_F(AlgebraicSimplifierTest, PowExp) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Power(op::Exp(param0), param1)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Exp(op::Multiply(param0, param1))); +} + +// Test that ln(pow(A, B)) is simplified to ln(A)*B +TEST_F(AlgebraicSimplifierTest, LnPow) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* pow = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, param1)); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, 
HloOpcode::kLog, pow)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Log(op::Power(param0, param1))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Log(param0), param1)); +} + // Test that ln(exp(A)) is simplified to A TEST_F(AlgebraicSimplifierTest, LnExp) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -300,7 +532,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); @@ -315,7 +547,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->literal()), 1); + EXPECT_EQ(root->literal().GetFirstElement(), 1); } // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). 
@@ -325,7 +557,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); @@ -344,8 +576,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape())); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), - 1); + EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); } // Test that pow(A, 1) is simplified to A. @@ -355,7 +586,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); @@ -378,7 +609,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* two = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); + HloInstruction::CreateConstant(Literal::CreateR0(2))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); @@ -401,7 +632,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* negative_one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(-1))); + 
HloInstruction::CreateConstant(Literal::CreateR0(-1))); builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); @@ -416,8 +647,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Divide(op::Constant(), param0)); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), - 1); + EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -451,7 +681,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -519,7 +749,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction::CreateConstant(Literal::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1})); @@ -550,7 +780,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction::CreateConstant(Literal::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1})); @@ -735,7 
+965,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), HloOpcode::kMaximum, movable_reshape, zero)); @@ -753,6 +983,34 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { op::Reshape(op::Maximum(param, zero))); } +// Regression test for a bug in the reshape sinking transformation, where +// moving a reshape to a scalar led to a crash. +TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1}), "param")); + HloInstruction* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param)); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, reshape, zero)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Maximum(op::Reshape(param), zero)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + + simplifier.Run(module.get()).ValueOrDie(); + + EXPECT_THAT(computation->root_instruction(), + op::Maximum(op::Reshape(param), zero)); +} + TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -1035,7 +1293,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { 
builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 2}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); PaddingConfig no_padding; for (int i = 0; i < 2; ++i) { auto dimension = no_padding.add_dimensions(); @@ -1066,7 +1324,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {10, 10}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); PaddingConfig padding; int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {2, -3}; @@ -1376,9 +1634,9 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, param0, min_value)); builder.AddInstruction( @@ -1406,9 +1664,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + 
HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMaximum, param0, max_value)); builder.AddInstruction( @@ -1437,9 +1695,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kMaximum, param0, max_value)); builder.AddInstruction( @@ -1497,9 +1755,9 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMaximum, param0, max_value)); HloInstruction* fmax = builder.AddInstruction( @@ -1566,7 +1824,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloComputation::Builder builder(TestName()); HloInstruction* forty_two = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape 
= ShapeUtil::MakeShape(F32, {4, 5, 6}); HloInstruction* broadcast = @@ -1614,7 +1872,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); @@ -1645,7 +1903,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, pad, reduce_init_value, window, @@ -1714,9 +1972,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloComputation::Builder call_builder(TestName() + ".Call"); HloInstruction* zero = call_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0.0f}))); + HloInstruction::CreateConstant(Literal::CreateR1({0.0f}))); HloInstruction* one = call_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0f}))); + HloInstruction::CreateConstant(Literal::CreateR1({1.0f}))); builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); @@ -1728,6 +1986,26 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); } +// Test that a constant with tuple shape becomes a tuple of constants. 
+TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { + HloComputation::Builder builder(TestName()); + const float constant_scalar = 7.3f; + std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; + std::unique_ptr value = + Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), + Literal::CreateR1(constant_vector).get()}); + builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(op::Constant(), op::Constant())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 66d54ad3802fe442decd11335eddf74bdd1cf950..9abe30e3f371cc294c36c1dcd743224b11b0c4f5 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -22,7 +22,6 @@ limitations under the License. 
#define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const { return platform_; } -BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) { - number_of_replicas_ = number_of_replicas; - return *this; -} - -int BackendOptions::number_of_replicas() const { return number_of_replicas_; } - BackendOptions& BackendOptions::set_intra_op_parallelism_threads( int num_threads) { intra_op_parallelism_threads_ = num_threads; @@ -85,20 +77,17 @@ struct Backend::EigenThreadPoolWrapper { /* static */ StatusOr> Backend::CreateBackend( const BackendOptions& options) { - int64 replica_count = options.number_of_replicas(); - if (replica_count == -1) { - legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags(); - replica_count = flags->xla_replicas; - } perftools::gputools::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto stream_executors, PlatformUtil::GetStreamExecutors(platform)); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); + TF_ASSIGN_OR_RETURN(auto computation_placer, + ComputationPlacer::GetForPlatform(platform)); std::unique_ptr backend( - new Backend(replica_count, platform, compiler, stream_executors, - transfer_manager, options.intra_op_parallelism_threads())); + new Backend(platform, compiler, stream_executors, transfer_manager, + computation_placer, options.intra_op_parallelism_threads())); return std::move(backend); } @@ -132,34 +121,25 @@ StatusOr Backend::BorrowStream( } Backend::Backend( - int64 replica_count, perftools::gputools::Platform* platform, - Compiler* compiler, 
+ perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, int intra_op_parallelism_threads) + TransferManager* transfer_manager, ComputationPlacer* computation_placer, + int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), - replica_count_(replica_count) { + computation_placer_(computation_placer) { // The given set of stream executors set may include invalid executors. for (se::StreamExecutor* exec : stream_executors) { if (exec != nullptr) { stream_executors_.push_back(exec); } } - CHECK_GE(replica_count, 1) << "Must request at least 1 replica."; - // Create a memory allocator for the valid stream executors. memory_allocator_ = MakeUnique(platform, stream_executors); - - // First check that there are some non-null stream executors to avoid issuing - // an error mentioning replicas in the common case of requesting just 1 - // replica, which means no replication. CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; - CHECK_GE(stream_executors_.size(), replica_count) - << "Requested more replicas than there are devices for backend " - << platform_->Name() << '.'; if (platform->id() == se::host::kHostPlatformId) { inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( @@ -179,36 +159,6 @@ int Backend::default_device_ordinal() const { return default_stream_executor()->device_ordinal(); } -StatusOr> Backend::Replicas( - int device_ordinal) const { - if (stream_executors_[device_ordinal] == nullptr) { - return InvalidArgument("device %s not supported by XLA service", - device_name(device_ordinal).c_str()); - } - - // Find replica_count_ stream executors starting from the given device - // ordinal. 
- std::vector replicas; - for (se::StreamExecutor* exec : stream_executors_) { - CHECK(exec != nullptr); - if (exec->device_ordinal() >= device_ordinal) { - replicas.push_back(exec); - if (replicas.size() >= replica_count_) { - return replicas; - } - } - } - - return InvalidArgument( - "Not enough devices for replicas for the device ordinal %d", - device_ordinal); -} - -std::vector Backend::Replicas() const { - CHECK_GE(stream_executors_.size(), replica_count_); - return Replicas(default_device_ordinal()).ValueOrDie(); -} - tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { return inter_op_thread_pool_.get(); } diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index e0b15dc43f25244bc1a3e3c5cdc45877d4d11804..b5ca483b7274d20c31e932d748b6a4c9dea926f9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -46,12 +47,6 @@ class BackendOptions { BackendOptions& set_platform(perftools::gputools::Platform* platform); perftools::gputools::Platform* platform() const; - // Set the number of replicas to use when compiling replicated - // programs. The default is -1 meaning that the value is read from - // the xla_replicas flag. - BackendOptions& set_number_of_replicas(int number_of_replicas); - int number_of_replicas() const; - // Sets the thread pool size for parallel execution of an individual operator. // The default value of -1 will result in initializing the thread pool with // the number of threads equal to the number of cores in the system. 
@@ -60,7 +55,6 @@ class BackendOptions { private: perftools::gputools::Platform* platform_ = nullptr; - int number_of_replicas_ = -1; int intra_op_parallelism_threads_ = -1; }; @@ -74,8 +68,7 @@ class Backend { public: using StreamPtr = Pool::SmartPtr; - // Creates a new backend for the given platform with the given number of - // replicas. + // Creates a new backend. static StatusOr> CreateBackend( const BackendOptions& options); @@ -92,6 +85,7 @@ class Backend { return memory_allocator_.get(); } TransferManager* transfer_manager() const { return transfer_manager_; } + ComputationPlacer* computation_placer() const { return computation_placer_; } // Returns the number of devices of the platform type which are visible. Not // all of these devices may be usable by XLA. @@ -107,24 +101,13 @@ class Backend { return stream_executors_; } - // Returns the replicas for the default stream executor. - // - // When the number of replicas is R, the first R stream executors are assigned - // to the replicas of the default stream executor. - std::vector Replicas() const; - - // Returns the replicas for the given device_ordinal. The given device ordinal - // is considered to be the first device ordinal among the replicas. Returns an - // error status if the stream executor for the given given device ordinal does - // not exist or if there are not enough stream executors for the replicas. - StatusOr> Replicas( - int device_ordinal) const; - - // Return the stream executor for the given device ordinal. + // Returns the stream executor for the given device ordinal. StatusOr stream_executor( int device_ordinal) const; - // Return the stream executor for the default device ordinal. + // Returns the stream executor for the default device ordinal. This stream + // executor can only be used when the number of computations is 1 (replication + // can be > 1). 
perftools::gputools::StreamExecutor* default_stream_executor() const { CHECK(!stream_executors_.empty()); return stream_executors_[0]; @@ -174,18 +157,19 @@ class Backend { private: struct EigenThreadPoolWrapper; - Backend(int64 replica_count, perftools::gputools::Platform* platform, - Compiler* compiler, + Backend(perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, int intra_op_parallelism_threads); + TransferManager* transfer_manager, + ComputationPlacer* computation_placer, + int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; perftools::gputools::Platform* platform_; Compiler* compiler_; TransferManager* transfer_manager_; - int64 replica_count_ = -1; + ComputationPlacer* computation_placer_; // Vector of stream executors. stream_executors_[0] is the default executor. std::vector stream_executors_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index f91eb0207a23fe55394d59ed99a0d08cf16aa285..f372b18f7e7b91a300cbcf483c69434de2850bd6 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,12 +22,12 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -66,6 +66,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { + VLOG(4) << "Trying to add " << buffer << " to " << this; CHECK(assigned_buffers_.count(&buffer) == 0) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -212,10 +213,14 @@ bool BufferAssignment::HasTopLevelAllocation( StatusOr BufferAssignment::GetUniqueSlice( const HloInstruction* instruction, const ShapeIndex& index) const { + VLOG(3) << "Trying to find unique slice for " << instruction->name() << " [" + << index << "]"; BufferAllocation::Slice result; for (const LogicalBuffer* buffer : GetPointsToSet(instruction).element(index)) { + VLOG(3) << "Examining buffer " << *buffer; if (HasAllocation(*buffer)) { + VLOG(3) << "Has allocation"; const BufferAllocation::Slice slice = GetAssignedAllocation(*buffer).GetSlice(*buffer); if (result.allocation() == nullptr) { @@ -226,6 +231,8 @@ StatusOr BufferAssignment::GetUniqueSlice( "be determined at compile-time.", instruction->name().c_str(), index.ToString().c_str()); } + } else { + VLOG(3) << "No allocation"; } } if (result.allocation() == nullptr) { @@ -320,8 +327,9 @@ void BufferAssignment::CombineTempAllocations() { // Each temp allocation is placed end-to-end, accounting for alignment. 
// The offset of each buffer in the combined allocation is computed from // the base offset of the allocation. + int64 alignment = color_alignment_(color); const int64 base = - RoundUpToNearest(combined_allocation->size(), alignment_); + RoundUpToNearest(combined_allocation->size(), alignment); combined_allocation->set_size(base + temp_allocation.size()); for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) { const LogicalBuffer* buffer = buffer_offset_size.first; @@ -575,12 +583,13 @@ Status GatherComputationsByAllocationType( /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, int64 alignment, + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing, TuplePointsToAnalysis::Colorer colorer) { - BufferAssigner assigner(alignment, allow_input_output_aliasing, - std::move(colorer)); + BufferAssigner assigner(allow_input_output_aliasing, std::move(colorer)); return assigner.CreateAssignment(module, std::move(hlo_ordering), - std::move(buffer_size)); + std::move(buffer_size), + std::move(color_alignment)); } bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, @@ -662,7 +671,8 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } Status BufferAssigner::AssignBuffersForComputation( - const HloComputation* computation, bool is_thread_local, + const HloComputation* computation, const DebugOptions& debug_options, + bool is_thread_local, const FlatSet& colocated_buffers, const FlatSet& colocated_allocations, FlatMap>* @@ -786,10 +796,7 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - legacy_flags::BufferAssignmentFlags* flags = - legacy_flags::GetBufferAssignmentFlags(); - if (!flags->xla_enable_buffer_reuse || is_thread_local || - instruction->opcode() == HloOpcode::kCustomCall) { + if (is_thread_local || instruction->opcode() == 
HloOpcode::kCustomCall) { // Custom call operations never have reusable buffers. Also we do not // reuse thread-local buffers for now, because they are dynamically // allocated and their lifetimes are hard to compute. @@ -938,11 +945,13 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( } auto color_map = SplitBuffersByColor(all_buffers_to_assign); for (auto& single_colored_set : color_map) { - VLOG(2) << "Simulating heap for color " << single_colored_set.first; + auto color = single_colored_set.first; + VLOG(2) << "Simulating heap for color " << color; + int64 alignment = assignment->color_alignment_(color); TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( - MakeUnique(alignment_)), + MakeUnique(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), assignment->buffer_size_, @@ -963,11 +972,13 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( CHECK(instruction_sequence != nullptr) << computation->name(); auto color_map = SplitBuffersByColor(buffers_to_assign); for (auto& single_colored_set : color_map) { - VLOG(2) << "Simulating heap for color " << single_colored_set.first; + auto color = single_colored_set.first; + VLOG(2) << "Simulating heap for color " << color; + int64 alignment = assignment->color_alignment_(color); TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( - MakeUnique(alignment_)), + MakeUnique(alignment)), *computation, *instruction_sequence, assignment->points_to_analysis(), assignment->buffer_size_, @@ -1074,7 +1085,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( // different while instructions. 
void BufferAssigner::AddWhileSetToColocatedBufferSets( const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const LogicalBuffer* while_init_buffer, + const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, const HloComputation& computation, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets) { @@ -1137,16 +1149,30 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets( continue; } - // Skip predecessor set if the live range of any predecessor buffers - // overlaps with 'while_init_buffer'. Note that tuple element buffer - // forwarding can cause the same buffer to appear on both sides of the - // interference comparison below. - if (std::any_of( - predecessor_while_buffers.begin(), predecessor_while_buffers.end(), - [while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) { - return while_init_buffer->id() != buffer->id() && - buffer_liveness.MayInterfere(*while_init_buffer, *buffer); - })) { + // Skip predecessor set if the live range of any predecessor + // buffers overlaps with 'while_init_buffer' or + // 'while_result_buffer' (we need to check both since they're + // aliased together, but the points-to analysis is unaware of this + // aliasing). Note that tuple element buffer forwarding can cause + // the same buffer to appear on both sides of the interference + // comparison below. 
+ auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) { + if (while_init_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) { + return true; + } + + if (while_result_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) { + return true; + } + + return false; + }; + + if (std::any_of(predecessor_while_buffers.begin(), + predecessor_while_buffers.end(), + may_interfere_with_init_or_result)) { continue; } @@ -1209,8 +1235,8 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet(while_hlo->operand(0), index, points_to_analysis, &colocated_set); // Add while.result. - AddBufferToColocatedSet(while_hlo, index, points_to_analysis, - &colocated_set); + auto* result_buffer = AddBufferToColocatedSet( + while_hlo, index, points_to_analysis, &colocated_set); // Add while.cond.parameter. AddBufferToColocatedSet( while_hlo->while_condition()->parameter_instruction(0), index, @@ -1224,8 +1250,9 @@ void BufferAssigner::BuildColocatedBufferSets( while_hlo->while_body()->root_instruction(), index, points_to_analysis, &colocated_set); AddWhileSetToColocatedBufferSets( - colocated_set, init_buffer, while_hlo, *computation, - buffer_liveness, buffer_size, colocated_buffer_sets); + colocated_set, init_buffer, result_buffer, while_hlo, + *computation, buffer_liveness, buffer_size, + colocated_buffer_sets); }); } else if (opcode == HloOpcode::kCall) { const HloInstruction* call_hlo = instruction; @@ -1300,10 +1327,10 @@ void BufferAssigner::AssignColocatedBufferSets( StatusOr> BufferAssigner::CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size) { + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) { TF_ASSIGN_OR_RETURN(std::unique_ptr liveness, - BufferLiveness::Run(module, std::move(hlo_ordering), - std::move(colorer_))); + BufferLiveness::Run(module, 
std::move(hlo_ordering))); VLOG(1) << "Assigning buffers to module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); @@ -1311,8 +1338,9 @@ StatusOr> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); // Can't use MakeUnique because BufferAssignment constructor is private. - std::unique_ptr assignment(new BufferAssignment( - module, std::move(liveness), alignment_, std::move(buffer_size))); + std::unique_ptr assignment( + new BufferAssignment(module, std::move(liveness), std::move(buffer_size), + std::move(color_alignment))); // Assign buffers with the tightest constraints first (colocated buffer sets). // Once b/32491382 enables module-level liveness analysis, we may be able @@ -1323,6 +1351,10 @@ StatusOr> BufferAssigner::CreateAssignment( std::vector colocated_buffer_sets; BuildColocatedBufferSets(module, assignment->liveness(), assignment->buffer_size_, &colocated_buffer_sets); + TF_RETURN_IF_ERROR(colorer_(&assignment->mutable_points_to_analysis())); + VLOG(3) << "After coloring:"; + XLA_VLOG_LINES(3, assignment->points_to_analysis().ToString()); + AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(), &colocated_buffers, &colocated_allocations); @@ -1337,9 +1369,9 @@ StatusOr> BufferAssigner::CreateAssignment( buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, /*is_thread_local=*/false, colocated_buffers, - colocated_allocations, &buffers_to_assign_sequentially, - assignment.get())); + computation, module->config().debug_options(), + /*is_thread_local=*/false, colocated_buffers, colocated_allocations, + &buffers_to_assign_sequentially, assignment.get())); } // Assign buffers with sequential ordering, if any. 
If all global computations // are sequential, we can run heap simuation on the whole module, which @@ -1355,9 +1387,9 @@ StatusOr> BufferAssigner::CreateAssignment( for (auto* computation : thread_local_computations) { TF_RET_CHECK(computation != module->entry_computation()); TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, /*is_thread_local=*/true, colocated_buffers, - colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, - assignment.get())); + computation, module->config().debug_options(), + /*is_thread_local=*/true, colocated_buffers, colocated_allocations, + /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); } // Mark all buffers which may be live out of the entry computation as diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index b3933f11c1e6ae3e7ffcc990442183338788caf4..e0b49505d283ad92ab50d47a807bb31ed9c5f9d0 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -320,6 +320,10 @@ class BufferAssignment { return liveness_->points_to_analysis(); } + TuplePointsToAnalysis& mutable_points_to_analysis() const { + return liveness_->mutable_points_to_analysis(); + } + // Returns the BufferLiveness object used to construct this assignment. const BufferLiveness& liveness() const { return *liveness_; } @@ -351,12 +355,12 @@ class BufferAssignment { explicit BufferAssignment(const HloModule* module, std::unique_ptr liveness, - int64 alignment, - LogicalBuffer::SizeFunction buffer_size) + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) : module_(module), liveness_(std::move(liveness)), - alignment_(alignment), - buffer_size_(std::move(buffer_size)) {} + buffer_size_(std::move(buffer_size)), + color_alignment_(std::move(color_alignment)) {} // Creates and returns a new BufferAllocation, with no assigned // LogicalBuffers. 
Ownership is maintained internally. @@ -402,11 +406,13 @@ class BufferAssignment { const HloModule* module_; const std::unique_ptr liveness_; - const int64 alignment_; // Function which returns the buffer size for a given logical buffer (shape). LogicalBuffer::SizeFunction buffer_size_; + // Function which returns the alignment for a given logical buffer color. + LogicalBuffer::AlignmentFunction color_alignment_; + Stats stats_; std::vector heap_simulator_traces_; @@ -417,36 +423,38 @@ class BufferAssignment { class BufferAssigner { public: // Build and return a BufferAssignment for the given module. The given - // HloOrdering is used to determine buffer liveness. buffer_size is a function - // which returns the size of a LogicalBuffer. Alignment is the minimum - // alignment of any buffer. allow_input_output_aliasing specifies whether - // input buffer are allowed to be reused as outbut buffers by the client code. + // HloOrdering is used to determine buffer liveness. buffer_size and + // color_alignment are functions which returns the size and alignment of a + // LogicalBuffer. allow_input_output_aliasing specifies whether input buffer + // are allowed to be reused as outbut buffers by the client code. static StatusOr> Run( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, int64 alignment, + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing = false, TuplePointsToAnalysis::Colorer colorer = TuplePointsToAnalysis::DefaultColorer()); private: - BufferAssigner(int64 alignment, bool allow_input_output_aliasing, + BufferAssigner(bool allow_input_output_aliasing, TuplePointsToAnalysis::Colorer colorer) - : alignment_(alignment), - allow_input_output_aliasing_(allow_input_output_aliasing), + : allow_input_output_aliasing_(allow_input_output_aliasing), colorer_(colorer) {} virtual ~BufferAssigner() = default; // Create a buffer assignment. 
StatusOr> CreateAssignment( const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size); + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment); // Assigns buffers to the instructions in the given computation. "assignment" // is modified to reflect the new buffer assignments. If is_thread_local is // true, then all assigned buffers have the is_thread_local flag set to // true. Status AssignBuffersForComputation( - const HloComputation* computation, bool is_thread_local, + const HloComputation* computation, const DebugOptions& debug_options, + bool is_thread_local, const tensorflow::gtl::FlatSet& colocated_buffers, const tensorflow::gtl::FlatSet& colocated_allocations, @@ -511,7 +519,8 @@ class BufferAssigner { // colocated buffers for while instructions. void AddWhileSetToColocatedBufferSets( const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const LogicalBuffer* while_init_buffer, + const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, const HloComputation& computation, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets); @@ -524,9 +533,6 @@ class BufferAssigner { SplitBuffersByColor( const tensorflow::gtl::FlatSet& buffers); - // Minimum alignment of any buffer. - int64 alignment_; - // If true, buffer assignments assumes that input parameter buffers and output // buffers can be shared if their sizes match. bool allow_input_output_aliasing_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 892f67a8812823a6f156dc6098bf6b39fa800d3c..21ba083dd9d9717b47f7a28bdafea623eaf0e0c5 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -85,17 +86,19 @@ class BufferAssignmentTest : public HloTestBase { int64 alignment = 1) { return BufferAssigner::Run( module, MakeUnique(module), - backend_->compiler()->BufferSizeBytesFunction(), alignment) + backend_->compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); } std::unique_ptr RunColoredBufferAssignment( HloModule* module, TuplePointsToAnalysis::Colorer colorer, int64 alignment = 1) { - return BufferAssigner::Run(module, - MakeUnique(module), - backend_->compiler()->BufferSizeBytesFunction(), - alignment, false, std::move(colorer)) + return BufferAssigner::Run( + module, MakeUnique(module), + backend_->compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }, false, + std::move(colorer)) .ConsumeValueOrDie(); } @@ -105,7 +108,7 @@ class BufferAssignmentTest : public HloTestBase { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); auto value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); return builder.Build(); @@ -122,7 +125,7 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + HloInstruction::CreateConstant(Literal::CreateR0(4))); auto param = builder.AddInstruction( 
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( @@ -147,9 +150,9 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto constv = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto indexc = builder.AddInstruction( @@ -264,7 +267,7 @@ static bool BuffersDistinct(const std::vector& a, TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -278,9 +281,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) { // no buffers assigned, and their consumer has a buffer. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); + Literal::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); auto module = CreateNewModule(); @@ -298,7 +301,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { // This computation copies a constant to output. 
auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); auto module = CreateNewModule(); @@ -378,12 +381,15 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunColoredBufferAssignment( - module.get(), - [](const HloInstruction* instruction, const ShapeIndex& index) { - static int64 serial = 0; - return LogicalBuffer::Color(serial++); - }); + auto colorer = [](TuplePointsToAnalysis* points_to_analysis) { + int color = 0; + for (auto& buffer : points_to_analysis->logical_buffers()) { + buffer->set_color(LogicalBuffer::Color(color++)); + } + return Status::OK(); + }; + + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -430,14 +436,23 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunColoredBufferAssignment( - module.get(), - [](const HloInstruction* instruction, const ShapeIndex& index) { - return (instruction->opcode() == HloOpcode::kAdd || - instruction->opcode() == HloOpcode::kMultiply) - ? 
LogicalBuffer::Color(1) - : LogicalBuffer::Color(0); - }); + auto colorer = [](TuplePointsToAnalysis* points_to_analysis) { + for (auto& buffer : points_to_analysis->logical_buffers()) { + const auto& aliases = points_to_analysis->GetBufferAliases(*buffer); + for (const auto& alias : aliases) { + if (alias.instruction()->opcode() == HloOpcode::kAdd || + alias.instruction()->opcode() == HloOpcode::kMultiply) { + buffer->set_color(LogicalBuffer::Color(1)); + } + } + if (!buffer->has_color()) { + buffer->set_color(LogicalBuffer::Color(0)); + } + } + return Status::OK(); + }; + + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -586,7 +601,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto exp2 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( /*shape=*/f32vec10_, /*operand=*/exp2, @@ -634,9 +649,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto const3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( @@ -996,9 +1011,10 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { // Test a computation that returns a tuple parameter. 
auto builder = HloComputation::Builder(TestName()); auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), - ShapeUtil::MakeShape(F32, {}), - ShapeUtil::MakeShape(S32, {42})}), + 0, + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), + ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(S32, {42})}), "param0")); auto module = CreateNewModule(); @@ -1027,10 +1043,11 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { // parameter. auto builder = HloComputation::Builder(TestName()); auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}), - ShapeUtil::MakeShape(S32, {101})})}), + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}), + ShapeUtil::MakeShape(S32, {101})})}), "param0")); auto tuple_element = builder.AddInstruction(HloInstruction::CreateGetTupleElement( @@ -1075,9 +1092,8 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output is // properly handled. 
auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}))); + builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple( + {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1369,9 +1385,9 @@ class WhileBufferAssignmentTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto ten = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + HloInstruction::CreateConstant(Literal::CreateR0(10))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); return builder.Build(); @@ -1399,7 +1415,8 @@ class WhileBufferAssignmentTest : public HloTestBase { CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, MakeUnique(module, sequence), - ByteSizeOf, alignment) + ByteSizeOf, + [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); } @@ -1429,7 +1446,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateParameter(2, data_shape_, "weights1")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1484,7 +1501,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = 
builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1532,16 +1549,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param")); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1)); sub_computation = module->AddEmbeddedComputation(builder.Build(add)); } auto builder = HloComputation::Builder(TestName()); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto call1 = builder.AddInstruction( HloInstruction::CreateCall(r0f32, {constant2}, sub_computation)); auto call2 = builder.AddInstruction( @@ -1565,6 +1582,105 @@ TEST_F(BufferAssignmentTest, TwoCalls) { EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } +static bool IsPostOrderTraversal( + const std::vector& sequence) { + tensorflow::gtl::FlatSet seen_so_far; + auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { + return seen_so_far.count(instruction) == 0; + }; + + for (auto instruction : sequence) { + if (std::any_of(instruction->operands().begin(), + instruction->operands().end(), has_not_been_seen_yet) || + std::any_of(instruction->control_predecessors().begin(), + instruction->control_predecessors().end(), + has_not_been_seen_yet)) { + return false; // Not a post order. 
+ } + if (!seen_so_far.insert(instruction).second) { + return false; // Not a "traversal". + } + } + + return true; +} + +TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder(TestName()); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto input1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, data_shape_, "input1")); + auto weights1 = builder.AddInstruction( + HloInstruction::CreateParameter(3, data_shape_, "weights1")); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, one, {1})); + + auto cond = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input1, weights1, output1})); + + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); + + auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( + while0->shape(), HloOpcode::kAdd, while0, while1)); + module->AddEntryComputation(builder.Build()); + + RunCopyInsertion(module.get()); + + { + FlattenCallGraph flatten; + TF_ASSIGN_OR_ASSERT_OK(bool result, 
flatten.Run(module.get())); + EXPECT_TRUE(result); + } + + auto sequence = + CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + + // To trigger b/38494731, we want a specific Hlo sequence for the + // root computation, so we overwrite that entry with a manually + // crafted sequence. + std::vector sequence_for_buffer_assigment = { + input1, weights1, one, output1, tuple1, while1, input0, + weights0, zero, output0, tuple0, while0, root_add}; + + // If this ASSERT_TRUE fails, we constructed a bogus sequence above + // and this test itself is buggy. + ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment)); + + sequence[module->entry_computation()] = + std::move(sequence_for_buffer_assigment); + + auto assignment = + BufferAssigner::Run( + module.get(), + MakeUnique(module.get(), sequence), ByteSizeOf, + [](LogicalBuffer::Color) { return 1; }) + .ConsumeValueOrDie(); + + EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); +} + // Test buffer assignment for while nodes with multiple uses. // TODO(b/37245345): Fix buffer assignment for this case. 
TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { @@ -1577,7 +1693,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 1b14c26340f6c1922bf35457fe7f1367ed953df0..6720a90ef85173e1b3116340e6ab906c54965c78 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -37,17 +37,15 @@ namespace xla { /* static */ StatusOr> BufferLiveness::Run( - const HloModule* module, std::unique_ptr hlo_ordering, - TuplePointsToAnalysis::Colorer colorer) { + const HloModule* module, std::unique_ptr hlo_ordering) { std::unique_ptr liveness( - new BufferLiveness(module, std::move(hlo_ordering), std::move(colorer))); + new BufferLiveness(module, std::move(hlo_ordering))); TF_RETURN_IF_ERROR(liveness->Analyze()); return std::move(liveness); } tensorflow::Status BufferLiveness::Analyze() { - TF_ASSIGN_OR_RETURN(points_to_analysis_, - TuplePointsToAnalysis::Run(module_, colorer_)); + TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto& computation : module_->computations()) { // Gather all instructions whose buffers might alias other instructions into // the set aliased_buffers_. 
This includes those contained as a tuple @@ -122,7 +120,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, if (b.instruction()->IsUserOf(alias.instruction()) && !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), b.instruction(), b.index(), - points_to_analysis())) { + &points_to_analysis())) { return false; } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 9bb2564a8312f0d80e01f40cb18f99d5ad0e1771..7e484430eab466b7aa8aa88138acfc7b03bd9873 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -39,9 +39,7 @@ class BufferLiveness { // Constructs a buffer liveness object for the given module assuming the given // HLO instruction ordering. static StatusOr> Run( - const HloModule* module, std::unique_ptr hlo_ordering, - TuplePointsToAnalysis::Colorer colorer = - TuplePointsToAnalysis::DefaultColorer()); + const HloModule* module, std::unique_ptr hlo_ordering); // Returns true if the live range of the buffer containing the output of 'a' // may overlap with the live range of the buffer of 'b'. If instruction 'a' @@ -63,6 +61,9 @@ class BufferLiveness { const TuplePointsToAnalysis& points_to_analysis() const { return *points_to_analysis_; } + TuplePointsToAnalysis& mutable_points_to_analysis() const { + return *points_to_analysis_; + } // Returns the underlying hlo ordering used for this liveness analysis. const HloOrdering& hlo_ordering() const { return *hlo_ordering_; } @@ -71,11 +72,8 @@ class BufferLiveness { private: explicit BufferLiveness(const HloModule* module, - std::unique_ptr hlo_ordering, - TuplePointsToAnalysis::Colorer colorer) - : module_(module), - hlo_ordering_(std::move(hlo_ordering)), - colorer_(colorer) {} + std::unique_ptr hlo_ordering) + : module_(module), hlo_ordering_(std::move(hlo_ordering)) {} // Perform buffer liveness analysis. 
This method must be called prior to // MayInterfere or MaybeLiveOut. @@ -98,8 +96,6 @@ class BufferLiveness { tensorflow::gtl::FlatSet maybe_live_out_buffers_; std::unique_ptr points_to_analysis_; - - TuplePointsToAnalysis::Colorer colorer_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index fda44ff4d2df18b90d308617cf845c9946227249..a5f7cc0aebe856931a122eb4bf56f87666ee38a0 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -397,13 +397,11 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + auto inner_tuple0 = Literal::MakeTuple( + {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}); + auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0(3).get()}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0->shape(), tuple_constant, 0)); @@ -450,7 +448,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -462,7 +460,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element1_shape, tuple_param0, 1)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1)); @@ -513,7 +511,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -585,7 +583,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); HloInstruction* slice = nullptr; if (update_uses_tuple_element1) { // Create a slice instruction as an additional user of 'gte1'. @@ -596,7 +594,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -715,7 +713,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); if (tuple_element1_has_two_uses) { // Add 'gte0' and 'gte1' to create another user of 'gte1'. @@ -724,7 +722,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index fa7b2a309525dd80d655e10474c5d49f9da14ea8..b450e0c40074344778109ed2ba8b2238cff7940e 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -133,6 +133,37 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { return nodes_[it->second]; } +bool CallGraph::DominatesHelper( + const HloComputation* a, const HloComputation* b, + tensorflow::gtl::FlatSet* visited) const { + if (a == b || ContainsKey(*visited, b)) { + // The call graph is guaranteed to be acyclic so any previously visited node + // we encounter was already determined to be dominated. 
+ return true; + } + + const CallGraphNode& b_node = GetNode(b); + if (b_node.callers().empty()) { + // We reached a root node without hitting 'a'. 'a' does not dominate 'b'. + return false; + } + + // Walk up the callers of 'b' until we hit 'a' or a root node (no callers). + visited->insert(b); + for (const HloComputation* b_caller : b_node.callers()) { + if (!DominatesHelper(a, b_caller, visited)) { + return false; + } + } + return true; +} + +bool CallGraph::Dominates(const HloComputation* a, + const HloComputation* b) const { + tensorflow::gtl::FlatSet visited; + return DominatesHelper(a, b, &visited); +} + namespace { // Returns the call context of a computation which is called from contexts 'a' diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 7f9990f06d4fee4c52fa516fc2f6031f5dab2bb9..a3297ff534f429279fd4674517db545f289af627 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -189,6 +189,20 @@ class CallGraph { Status VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes = true) const; + // Returns true if 'a' dominates 'b' in the call graph. Computation 'a' + // dominates computation 'b' iff all callgraph paths in the caller-to-callee + // direction from a root computation to 'b' pass through computation + // 'a'. Trivially, a computation dominates itself. + bool Dominates(const HloComputation* a, const HloComputation* b) const; + + // Returns whether 'instruction' is contained in 'computation' either directly + // ('instruction->parent' is 'computation') or indirectly ('computation' + // dominates 'instruction->parent' in the call graph). 
+ bool InstructionIsNestedIn(const HloInstruction* instruction, + const HloComputation* computation) const { + return Dominates(computation, instruction->parent()); + } + string ToString() const; private: @@ -205,6 +219,13 @@ class CallGraph { const VisitorFunction& visitor_func, const CallGraphNode& node, tensorflow::gtl::FlatSet* visited) const; + // Recursive helper for computing whether 'a' dominates 'b' in the call + // graph. 'b' is the currently visited node (initially the 'b' passed to Dominates), + // and 'visited' is the set of computations which have been visited. + bool DominatesHelper( + const HloComputation* a, const HloComputation* b, + tensorflow::gtl::FlatSet* visited) const; + // The HLO module represented by this call graph. const HloModule* module_ = nullptr; diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index e276473c90aa3fcc6b494537db6bceb841ade91e..3c22871b3bff193c27ee2eb639fe72306d532b97 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -81,7 +81,7 @@ class CallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -314,6 +314,37 @@ TEST_F(CallGraphTest, ComplexGraph) { EXPECT_LT(index_of(cond_computation), index_of(a_computation)); EXPECT_LT(index_of(c_computation), index_of(b_computation)); EXPECT_LT(index_of(b_computation), index_of(a_computation)); + + // Verify dominance relations between computations in the graph. + + // Entry dominates everybody, and is dominated by no one except itself. 
+ EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation)); + EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation)); + + // 'a' only dominates 'b' and 'c'. + EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation)); + EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation)); + EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation)); + EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation)); + + EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation)); + + EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation)); + EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation)); + + EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation)); } TEST_F(CallGraphTest, VisitSingletonComputation) { diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 0d1a439724a95231240227cfdf089cb2d74b3dd2..dc81323c997723d13403c7c9c4dde3b6af62a9bf 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ 
b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -52,17 +52,17 @@ CompileOnlyService::NewService(const ServiceOptions& options) { TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); - std::unique_ptr service( - new CompileOnlyService(compiler, std::move(compute_constant_backend))); + std::unique_ptr service(new CompileOnlyService( + options, compiler, std::move(compute_constant_backend))); return std::move(service); } CompileOnlyService::CompileOnlyService( - Compiler* compiler, std::unique_ptr compute_constant_backend) - : Service(/*backend=*/nullptr, std::move(compute_constant_backend)), - compiler_(compiler) { - runs_in_client_process_ = true; -} + const ServiceOptions& options, Compiler* compiler, + std::unique_ptr compute_constant_backend) + : Service(options, /*backend=*/nullptr, + std::move(compute_constant_backend)), + compiler_(compiler) {} StatusOr>> CompileOnlyService::CompileAheadOfTime( @@ -122,8 +122,7 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), - MakeHloDumper(), options); + return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index b19f4bd592162045a41e2ec82266826ce84096ef..0a1911cbd15b0278ec2c3ccc944ce4df80a683ed 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -55,7 +55,7 @@ class CompileOnlyService : public Service { // Override Service methods that require or imply the existence of an // execute backend. Note that this does not include TransferToClient, as - // computing contants produces global data that we may wish to transfer. + // computing constants produces global data that we may wish to transfer. 
tensorflow::Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); @@ -103,7 +103,8 @@ class CompileOnlyService : public Service { private: explicit CompileOnlyService( - Compiler* compiler, std::unique_ptr compute_constant_backend); + const ServiceOptions& options, Compiler* compiler, + std::unique_ptr compute_constant_backend); CompileOnlyService(const CompileOnlyService&) = delete; void operator=(const CompileOnlyService&) = delete; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 7ae285170e4b99ecf036eeb81eaee49ef34034ea..d5bd9214be44f4abd5f672168335ae1a259c9118 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -92,13 +92,6 @@ class AotCompilationOptions { // platform. class Compiler { public: - // Callback signature used to dump the HLO graph during compilation. - // Different compiler backends will call this as they please, providing - // a view of the HLO at different points in compilation -- context for the - // dump is indicated by the label string. - using HloDumper = - std::function; - virtual ~Compiler() {} // Returns the ID of the platform that this compiler targets. @@ -113,21 +106,20 @@ class Compiler { // // Use the overload below to compile computations that run in parallel. virtual StatusOr> Compile( - std::unique_ptr module, HloDumper dump_hlo, + std::unique_ptr module, perftools::gputools::StreamExecutor* executor) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. virtual StatusOr>> Compile( - std::vector> modules, HloDumper dump_hlo, + std::vector> modules, std::vector stream_exec) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. 
virtual StatusOr>> CompileAheadOfTime(std::vector> modules, - HloDumper dump_hlo, const AotCompilationOptions& options) = 0; ///// diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc new file mode 100644 index 0000000000000000000000000000000000000000..cdfa30dd9a7b6a5b9e58087491a9d99caaa1b998 --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -0,0 +1,152 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/computation_placer.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { + proto->set_replica_count(replica_count()); + proto->set_computation_count(computation_count()); + for (int computation = 0; computation < computation_count(); ++computation) { + DeviceAssignmentProto::ComputationDevice* computation_device = + proto->add_computation_devices(); + for (int replica = 0; replica < replica_count(); ++replica) { + computation_device->add_replica_device_ids((*this)(replica, computation)); + } + } + return Status::OK(); +} + +/* static */ StatusOr> +DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { + TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); + auto assignment = MakeUnique(proto.replica_count(), + proto.computation_count()); + for (int computation = 0; computation < proto.computation_count(); + ++computation) { + const auto& computation_device = proto.computation_devices(computation); + TF_RET_CHECK(computation_device.replica_device_ids_size() == + proto.replica_count()); + for (int replica = 0; replica < proto.replica_count(); ++replica) { + (*assignment)(replica, computation) = + 
computation_device.replica_device_ids(replica); + } + } + return std::move(assignment); +} + +StatusOr ComputationPlacer::DeviceId(int replica, int computation, + int replica_count, + int computation_count) { + TF_RET_CHECK(replica < replica_count); + TF_RET_CHECK(computation < computation_count); + + return computation * replica_count + replica; +} + +StatusOr ComputationPlacer::AssignDevices( + int replica_count, int computation_count) { + DeviceAssignment assignment(replica_count, computation_count); + for (int replica = 0; replica < replica_count; ++replica) { + for (int computation = 0; computation < computation_count; ++computation) { + TF_ASSIGN_OR_RETURN( + int device_id, + DeviceId(replica, computation, replica_count, computation_count)); + assignment(replica, computation) = device_id; + } + } + return std::move(assignment); +} + +/* static */ void ComputationPlacer::RegisterComputationPlacer( + se::Platform::Id platform_id, + ComputationPlacerCreationFunction creation_function) { + tensorflow::mutex_lock lock( + *ComputationPlacer::platform_computation_placer_mutex()); + auto* computation_placers = GetPlatformComputationPlacers(); + CHECK(computation_placers->find(platform_id) == computation_placers->end()); + (*computation_placers)[platform_id].creation_function = creation_function; +} + +/* static */ StatusOr ComputationPlacer::GetForPlatform( + const se::Platform* platform) { + tensorflow::mutex_lock lock( + *ComputationPlacer::platform_computation_placer_mutex()); + auto* computation_placers = GetPlatformComputationPlacers(); + + auto it = computation_placers->find(platform->id()); + if (it == computation_placers->end()) { + return NotFound( + "could not find registered computation placer for platform %s -- check " + "target linkage", + platform->Name().c_str()); + } + + if (it->second.placer == nullptr) { + // Lazily create the computation placer the first time it is needed. 
+ it->second.placer = (*it->second.creation_function)(); + } + + return it->second.placer.get(); +} + +/* static */ tensorflow::mutex* +ComputationPlacer::platform_computation_placer_mutex() { + static tensorflow::mutex* m = new tensorflow::mutex; + return m; +} + +/* static */ std::map* +ComputationPlacer::GetPlatformComputationPlacers() { + static auto* r = + new std::map; + return r; +} + +} // namespace xla + +static std::unique_ptr CreateComputationPlacer() { + return xla::MakeUnique(); +} + +static bool InitModule() { + xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId, + &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId, + &CreateComputationPlacer); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h new file mode 100644 index 0000000000000000000000000000000000000000..7d9abcd100dd9e878da885110bc1bd1ac65e3f84 --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Class that represents the device assignment for a set of XLA replicated +// computations. For R replicas and C computations, R * C devices are required +// to execute the computation in parallel. The assigned device ids can be accessed +// by assignment(replica, computation). +class DeviceAssignment : public Array2D { + public: + DeviceAssignment() {} + DeviceAssignment(int replica_count, int computation_count) + : Array2D(replica_count, computation_count, -1) { + CHECK_GT(replica_count, 0); + CHECK_GT(computation_count, 0); + } + + int replica_count() const { return height(); } + int computation_count() const { return width(); } + + // Protocol buffer serialization and deserialization. + Status Serialize(DeviceAssignmentProto* proto) const; + + // Return a std::unique_ptr instead of a DeviceAssignment + // directly because one of the supported TF platforms (mac) does not compile + // due to a StatusOr of an incomplete type (DeviceAssignment). + static StatusOr> Deserialize( + const DeviceAssignmentProto& proto); +}; + +// A generic implementation of the XLA computation placer, which assigns device +// ids to a set of replicated computations. 
+class ComputationPlacer { + public: + ComputationPlacer() {} + virtual ~ComputationPlacer() {} + + // Returns the device id assigned to the given replica and computation + // instance for [replica_count x computation_count] setup. The returned device + // id must match the assignment from AssignDevices(). + virtual StatusOr DeviceId(int replica, int computation, + int replica_count, int computation_count); + + // Returns the device ids assigned to a set of replicated computations, given + // the number of replicas and the number of computations. + virtual StatusOr AssignDevices(int replica_count, + int computation_count); + + using ComputationPlacerCreationFunction = + std::unique_ptr (*)(); + + // Registers a computation placer creation function for a particular platform. + static void RegisterComputationPlacer( + perftools::gputools::Platform::Id platform_id, + ComputationPlacerCreationFunction creation_function); + + // Returns the computation placer singleton pointer if it is available for the + // given platform, or an error status if it is not. + static StatusOr GetForPlatform( + const perftools::gputools::Platform* platform); + + private: + // Routine that returns the mutex that guards the platform-to-computation + // placer map. Done as a routine to ensure correct initialization ordering, + // since RegisterComputationPlacer can be called during program initialization + // time. + static tensorflow::mutex* platform_computation_placer_mutex(); + + // State kept for each kind of ComputationPlacer. Registration functions set + // up creation_function, and then we use that to lazily create "placer" the + // first time GetForPlatform is invoked for a particular id. + struct State { + std::unique_ptr placer; + ComputationPlacerCreationFunction creation_function = nullptr; + }; + + // Map from platform kind to computation placer singleton. 
+ static std::map* + GetPlatformComputationPlacers(); + + perftools::gputools::Platform::Id platform_id_; + + TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc index 9aa32a1fb76616e6c81043fabb053570a86d2619..70e25eebdb068db893e24aec0f72d09090ac7027 100644 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ b/tensorflow/compiler/xla/service/computation_tracker.cc @@ -216,6 +216,7 @@ StatusOr> ComputationTracker::BuildHloModule( TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_computation, computation->BuildHloComputation(versioned_handle.version, resolver, + config.debug_options(), include_unreachable_instructions)); // Add the newly created computation to VersionedHandle-to-HloComputation diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index cc77339bb63220d8c9da0500ee818c7b9fb02a4b..026be75757a9129c94e2c1c3083f226790d482f4 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -87,7 +87,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { TEST_F(CopyInsertionTest, SingleConstant) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); @@ -110,9 +110,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); 
HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -140,11 +140,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { // the computation result. Verify that copies are added properly. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -152,7 +152,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction::CreateTuple({constant3, constant2})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -196,9 +196,8 @@ TEST_F(CopyInsertionTest, BitcastConstant) { // The output of a bitcast is its operand (same buffer), so a bitcast // constant feeding the result must have a copy added. 
auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 42.0}))); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1.0, 42.0}))); HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); @@ -308,9 +307,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { // copy is added. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -318,7 +317,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction::CreateTuple({constant2, constant1})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); HloInstruction* gte = @@ -350,7 +349,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + HloInstruction::CreateConstant(Literal::CreateR0(10))); const Shape& loop_state_shape = nested ? 
nested_loop_state_shape_ : loop_state_shape_; auto loop_state = builder.AddInstruction( @@ -381,7 +380,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(1). @@ -419,7 +418,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -488,7 +487,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); @@ -503,9 +502,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); } - auto update = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // add1 = Add(in1, {1, 1, 1, 1, 1, 
1, 1, 1}) auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); @@ -538,7 +536,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( gte0->shape(), HloOpcode::kAdd, gte0, inc)); @@ -548,9 +546,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // GTE(GTE(loop_state, 1), 0) -> Add auto gte10 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); - auto update10 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update10 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, gte10, update10)); @@ -574,11 +571,10 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".While"); auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); if (nested) { auto inner_init = builder.AddInstruction( @@ -601,9 +597,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* 
BuildWhileInstruction_InitPointsToConstant() { auto builder = HloComputation::Builder(TestName() + ".While"); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, &builder); } @@ -620,11 +615,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto v1 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto v2 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -632,7 +627,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -644,7 +639,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto one_vec = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto data_init = @@ -657,12 
+652,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto data_init = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); - auto one_vec = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); @@ -677,7 +671,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const bool nested = ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(nested)); auto body = module_->AddEmbeddedComputation( diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 51ecbccd494fced68d5e92eda752f5292580a190..53410b09c87d55fa8595acd7ffb5a29cd2ddb1da 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -52,7 +52,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", 
"//tensorflow/compiler/xla/service:buffer_liveness", @@ -69,6 +68,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", + "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:inliner", @@ -151,9 +151,12 @@ cc_library( cc_library( name = "parallel_cpu_executable", srcs = ["parallel_cpu_executable.cc"], - hdrs = ["parallel_cpu_executable.h"], + hdrs = [ + "parallel_cpu_executable.h", + ], deps = [ ":cpu_runtime", + ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -177,7 +180,9 @@ cc_library( cc_library( name = "ir_emitter", srcs = ["ir_emitter.cc"], - hdrs = ["ir_emitter.h"], + hdrs = [ + "ir_emitter.h", + ], deps = [ ":cpu_runtime", ":dot_op_emitter", @@ -191,7 +196,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", @@ -222,7 +226,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", @@ -283,8 +286,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", 
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:analysis", @@ -334,6 +335,7 @@ cc_library( copts = runtime_copts(), deps = [ "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", ], ) @@ -405,6 +407,7 @@ cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -437,10 +440,15 @@ cc_library( cc_library( name = "cpu_parallelization_preparation", srcs = ["cpu_parallelization_preparation.cc"], - hdrs = ["cpu_parallelization_preparation.h"], + hdrs = [ + "cpu_parallelization_preparation.h", + ], deps = [ + ":ir_emission_utils", + ":shape_partition", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", @@ -472,7 +480,6 @@ cc_library( ":cpu_runtime", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:hlo", ], ) @@ -499,9 +506,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", ], ) @@ -511,6 +518,7 @@ cc_test( srcs = ["conv_canonicalization_test.cc"], deps = [ ":conv_canonicalization", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", @@ -518,6 +526,26 @@ cc_test( ], ) +cc_library( + name = "shape_partition", + srcs 
= ["shape_partition.cc"], + hdrs = ["shape_partition.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + ], +) + +cc_test( + name = "shape_partition_test", + srcs = ["shape_partition_test.cc"], + deps = [ + ":shape_partition", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 8ebf9ab110d080a017abb2077ac588672c8099bb..cd6cc09cfb55ba6b4697481cb6f8698610ecda92 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -35,8 +35,6 @@ limitations under the License. #include "external/llvm/include/llvm/Transforms/IPO.h" #include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h" #include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h" -#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" @@ -45,7 +43,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -66,14 +63,9 @@ operator()(llvm::Module& module) const { VLOG(2) << "IR before optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); - legacy_flags::CompilerFunctorFlags* flags = - legacy_flags::GetCompilerFunctorFlags(); - string dump_path = flags->xla_debug_cpu_dump_ir; - if (!dump_path.empty()) { - std::unique_ptr f; - TF_CHECK_OK(tensorflow::Env::Default()->NewAppendableFile(dump_path, &f)); - TF_CHECK_OK(f->Append(llvm_ir::DumpModuleToString(module))); - TF_CHECK_OK(f->Close()); + + if (pre_optimization_callback_) { + TF_CHECK_OK(pre_optimization_callback_(module)); } // Build up optimization pipeline. @@ -99,6 +91,10 @@ operator()(llvm::Module& module) const { VLOG(2) << "IR after optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); + if (post_optimization_callback_) { + TF_CHECK_OK(post_optimization_callback_(module)); + } + // Generate code. llvm::MCContext* mc_context; llvm::legacy::PassManager codegen_passes; @@ -156,12 +152,7 @@ std::vector VectorFunctionsForTargetLibraryInfoImpl( {"llvm.tanh.f32", runtime::kTanhV8F32, 8}, }; - // Our vectorized library calls are currently implement by calling into Eigen. - // As such, only emit calls to these routines if --xla_cpu_use_eigen is - // enabled. 
- legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - if (flags->xla_cpu_use_eigen && - (arch == llvm::Triple::x86 || llvm::Triple::x86_64)) { + if (arch == llvm::Triple::x86 || llvm::Triple::x86_64) { llvm::SmallVector features; feature_string.split(features, ',', -1, /*KeepEmpty=*/false); if (std::find(features.begin(), features.end(), "+sse4.1") != diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 17dadebe975b936b7d5d7a78ac69b890d9c8e7ac..94611b5814079140d99fbbadd9afc6248695f5e5 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -39,13 +39,22 @@ class CompilerFunctor { // Returns a VectorIntrinsics where all intrinsics are available. static VectorIntrinsics AllIntrinsics(); - explicit CompilerFunctor(llvm::TargetMachine* target_machine, - const Disassembler* disassembler, int opt_level, - const VectorIntrinsics& available_intrinsics) + // A callback of this type can be run before and/or after IR-level + // optimization to e.g. dump out the generated IR to disk or gather some + // statistics. + using OptimizationCallback = std::function; + + explicit CompilerFunctor( + llvm::TargetMachine* target_machine, const Disassembler* disassembler, + int opt_level, const VectorIntrinsics& available_intrinsics, + OptimizationCallback pre_optimization_callback = nullptr, + OptimizationCallback post_optimization_callback = nullptr) : target_machine_(target_machine), disassembler_(CHECK_NOTNULL(disassembler)), opt_level_(opt_level), - available_intrinsics_(available_intrinsics) {} + available_intrinsics_(available_intrinsics), + pre_optimization_callback_(pre_optimization_callback), + post_optimization_callback_(post_optimization_callback) {} // Compile a Module to an ObjectFile. 
llvm::object::OwningBinary operator()( @@ -61,6 +70,8 @@ class CompilerFunctor { const Disassembler* disassembler_; const unsigned opt_level_; const VectorIntrinsics available_intrinsics_; + OptimizationCallback pre_optimization_callback_; + OptimizationCallback post_optimization_callback_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index cdf43587b683e4e22d14d4fc08fa3705bc636de8..069979c6611e90ed2d95cbbe341198577cdf56cf 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -30,11 +29,6 @@ namespace xla { namespace cpu { StatusOr ConvCanonicalization::Run(HloModule* module) { - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - if (!flags->xla_cpu_use_eigen) { - return false; - } - bool changed = false; for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index f5ad431277d94039cd20cf51e0932413e87a0436..ec992f15e63b29ee67d16b6d841fedffd9c90f5b 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -59,11 +59,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. 
auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kInputFeatureCount, kBatchSize, kInputSize, kInputSize)))); // The kernel dimensions are in OIHW order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; @@ -113,11 +113,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in NHWC order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kBatchSize, kInputSize, kInputSize, kInputFeatureCount)))); // The kernel dimensions are in HWIO order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 34b99f2440b935402283d76d4a09475f4bfcb315..45da0cea330ca166ce7c8b0232ef334c0e4c1004 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "external/llvm/include/llvm/Support/TargetSelect.h" #include "external/llvm/include/llvm/Target/TargetMachine.h" #include "external/llvm/include/llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -70,6 +69,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/inliner.h" @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" namespace se = ::perftools::gputools; @@ -245,9 +246,9 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { +Status CpuCompiler::RunHloPasses(HloModule* module) { // Optimization pipeline. 
- HloPassPipeline pipeline("CPU", dump_hlo); + HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(); // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding @@ -256,8 +257,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { pipeline.AddPass(); { - auto& pass = pipeline.AddPass>("simplification", - dump_hlo); + auto& pass = + pipeline.AddPass>("simplification"); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -285,8 +286,13 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { /*enable_dot_simplification=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); // Outline ops in the entry computation into calls to subcomputations. + const int max_parallelism = + module->config().intra_op_parallelism_threads() > 0 + ? module->config().intra_op_parallelism_threads() + : tensorflow::port::NumSchedulableCPUs(); if (CpuParallelBackendRequested(module->config())) { - pipeline.AddPass(); + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction()); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -299,7 +305,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { if (CpuParallelBackendRequested(module->config())) { // Re-run the outlining, in case any copies were inserted into the entry // computation. - pipeline.AddPass(); + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction()); } pipeline.AddPass(); pipeline.AddPass(); @@ -310,6 +317,7 @@ namespace { // Align buffers to 16-byte boundaries. 
constexpr int64 kMemoryAlignment = 16; +auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; }; llvm::TargetOptions CompilerTargetOptions( const HloModuleConfig& module_config) { @@ -338,25 +346,45 @@ llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) { } } +Status AppendIRToFile(const string& file_name, const string& ir_module_string) { + std::unique_ptr f; + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->NewAppendableFile(file_name, &f)); + TF_RETURN_IF_ERROR(f->Append(ir_module_string)); + TF_RETURN_IF_ERROR(f->Close()); + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::Compile( - std::unique_ptr module, HloDumper dump_hlo, - se::StreamExecutor* stream_exec) { + std::unique_ptr module, se::StreamExecutor* stream_exec) { + VLOG(1) << "Compiling: " << module->name(); TF_RET_CHECK(stream_exec != nullptr); std::call_once(llvm_command_line_options_initialized, &InitializeLLVMCommandLineOptions, module->config()); + const string dump_ir_to = module->config().debug_options().xla_dump_ir_to(); + + auto dump_ir_to_disk = [dump_ir_to](const llvm::Module& module) { + if (!dump_ir_to.empty()) { + TF_RETURN_IF_ERROR( + AppendIRToFile(dump_ir_to, llvm_ir::DumpModuleToString(module))); + } + return Status::OK(); + }; + // Compile must be thread-safe so create a new LLVM context for the module. 
auto llvm_context = MakeUnique(); auto llvm_module = MakeUnique("__compute_module", *llvm_context); auto jit = MakeUnique(CompilerTargetOptions(module->config()), - CodeGenOptLevel(module->config())); + CodeGenOptLevel(module->config()), + dump_ir_to_disk, dump_ir_to_disk); llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), dump_hlo)); + TF_RETURN_IF_ERROR(RunHloPasses(module.get())); HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; @@ -367,8 +395,17 @@ StatusOr> CpuCompiler::Compile( } std::unique_ptr cpu_executable; - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); + + // Cache these flags here since we'll want to access them after the module's + // ownership is std::moved. + const bool embed_ir_in_executable = + module->config().debug_options().xla_embed_ir_in_executable(); + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (CpuParallelBackendRequested(module->config())) { + VLOG(1) << "Using parallel cpu backend"; + // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
// DependencyHloOrdering is used for the parallel emitter because the order @@ -379,12 +416,12 @@ StatusOr> CpuCompiler::Compile( std::unique_ptr assignment, BufferAssigner::Run(module.get(), MakeUnique(module.get()), - BufferSizeBytesFunction(), kMemoryAlignment)); + BufferSizeBytesFunction(), memory_alignment)); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -400,7 +437,7 @@ StatusOr> CpuCompiler::Compile( if (instruction->opcode() == HloOpcode::kConstant) { // Copy the constant out of the ProtocolBuffer so that we can give it a // higher alignment. - const void* data = LiteralUtil::InternalData(instruction->literal()); + const void* data = instruction->literal().InternalData(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( instruction, MakeUnique(size)); @@ -419,6 +456,7 @@ StatusOr> CpuCompiler::Compile( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), &hlo_to_profile_idx); + std::unique_ptr> function_names( new std::map()); for (auto embedded_computation : @@ -446,7 +484,7 @@ StatusOr> CpuCompiler::Compile( } string ir_module_string; - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } @@ -457,11 +495,13 @@ StatusOr> CpuCompiler::Compile( std::move(function_names), std::move(hlo_to_profile_idx), std::move(aligned_constants))); - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { static_cast(*cpu_executable) .set_ir_module_string(ir_module_string); } } else { + VLOG(1) << "Using sequential cpu backend"; + // Select an order for emitting the HLO instructions for each // 
computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). @@ -476,12 +516,12 @@ StatusOr> CpuCompiler::Compile( BufferAssigner::Run( module.get(), MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), kMemoryAlignment)); + BufferSizeBytesFunction(), memory_alignment)); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } // Each computation is a single function. Emit all embedded computations @@ -490,6 +530,7 @@ StatusOr> CpuCompiler::Compile( // before a caller computation. IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), &hlo_to_profile_idx); + for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { TF_RETURN_IF_ERROR( @@ -510,7 +551,7 @@ StatusOr> CpuCompiler::Compile( string function_name = llvm_ir::AsString(entry_function->getName()); string ir_module_string; - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } @@ -520,17 +561,18 @@ StatusOr> CpuCompiler::Compile( std::move(jit), std::move(assignment), std::move(module), function_name, std::move(hlo_to_profile_idx))); - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { static_cast(*cpu_executable) .set_ir_module_string(ir_module_string); } } + VLOG(1) << "Compilation finished"; return std::move(cpu_executable); } StatusOr>> CpuCompiler::Compile( - std::vector> modules, HloDumper dump_hlos, + std::vector> modules, std::vector stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on CPU."); @@ -538,7 +580,6 @@ StatusOr>> CpuCompiler::Compile( StatusOr>> 
CpuCompiler::CompileAheadOfTime(std::vector> modules, - HloDumper dump_hlo, const AotCompilationOptions& aot_options) { TF_RET_CHECK(!modules.empty()); std::call_once(llvm_command_line_options_initialized, @@ -627,8 +668,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::vector> results; for (size_t i = 0; i < modules.size(); ++i) { HloModule* module = modules[i].get(); + VLOG(1) << "Compiling ahead-of-time: " << module->name(); - TF_RETURN_IF_ERROR(RunHloPasses(module, dump_hlo)); + TF_RETURN_IF_ERROR(RunHloPasses(module)); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, @@ -640,13 +682,14 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::unique_ptr assignment, BufferAssigner::Run( module, MakeUnique(module, module_sequence), - BufferSizeBytesFunction(), kMemoryAlignment)); + BufferSizeBytesFunction(), memory_alignment)); - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } IrEmitter ir_emitter(*module, *assignment, &llvm_module, @@ -704,6 +747,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::move(object_file_data), std::move(buffer_sizes), result_slice.index())); } + + VLOG(1) << "Compilation finished"; return std::move(results); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 29fa4eac61beaa25e1662b1be5afa9757ab077ea..b82e181df2b883ddac7e7d39212fb28b07ca7b0c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -110,16 +110,15 
@@ class CpuCompiler : public Compiler { ~CpuCompiler() override {} StatusOr> Compile( - std::unique_ptr module, HloDumper dump_hlo, + std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( - std::vector> modules, HloDumper dump_hlo, + std::vector> modules, std::vector stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> modules, - HloDumper dump_hlo, const AotCompilationOptions& options) override; perftools::gputools::Platform::Id PlatformId() const override; @@ -132,7 +131,7 @@ class CpuCompiler : public Compiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* hlo_module, HloDumper dump_hlo); + Status RunHloPasses(HloModule* module); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index f6b1dcae75a773811f8c652dea36b7f3ca36e901..af931f7b0132bf1fa5714f268c463830259a779d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -15,19 +15,28 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace cpu { StatusOr ParallelizationPreparation::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ParallelizationPreparation ENTRY"); + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; + TF_ASSIGN_OR_RETURN(changed, RunParallelTaskAssignment(module)); + HloComputation* entry_computation = module->entry_computation(); std::unordered_set outlined; std::vector instructions_to_outline; @@ -44,13 +53,21 @@ StatusOr ParallelizationPreparation::Run(HloModule* module) { instruction->opcode() == HloOpcode::kConstant) { continue; } + + // Outline 'instruction' in isolation if it was assigned parallel tasks. + if (OutlineParallelizableInstruction(instruction)) { + outlined.insert(instruction); + changed = true; + continue; + } + instructions_to_outline.clear(); HloInstruction* outline_candidate = instruction; instructions_to_outline.push_back(outline_candidate); bool all_bitcasts = outline_candidate->opcode() == HloOpcode::kBitcast; // Outline sole users with the current instruction. 
- while (outline_candidate->users().size() == 1) { + while (CanOutlineWithUser(outline_candidate)) { HloInstruction* prior_candidate = outline_candidate; outline_candidate = *outline_candidate->users().begin(); all_bitcasts |= outline_candidate->opcode() == HloOpcode::kBitcast; @@ -120,8 +137,136 @@ StatusOr ParallelizationPreparation::Run(HloModule* module) { changed = true; } } + + XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT"); + XLA_VLOG_LINES(2, module->ToString()); return changed; } +StatusOr ParallelizationPreparation::RunParallelTaskAssignment( + HloModule* module) { + VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_; + bool changed = false; + // Run cost analysis on entry computation. + HloCostAnalysis cost_analysis(shape_size_); + HloComputation* computation = module->entry_computation(); + Status cost_status = computation->root_instruction()->Accept(&cost_analysis); + for (auto& instruction : computation->instructions()) { + // Currently, we do not assign parallel tasks to instructions with at least + // one of the following properties: + // *) Internal threading (library calls to kConv, kDot, and kCustomCall). + // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Tuple-shaped. + // TODO(b/27458679) Parallelize instructions which are skipped here. 
+ if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kCustomCall || + instruction->opcode() == HloOpcode::kSelectAndScatter || + (instruction->opcode() == HloOpcode::kConvolution && + PotentiallyImplementedAsEigenConvolution(*instruction)) || + PotentiallyImplementedAsEigenDot(*instruction) || + (instruction->opcode() == HloOpcode::kFusion && + instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || + ShapeUtil::IsTuple(instruction->shape())) { + continue; + } + + // Calculate target parallel task count in [1, max_parallelism_]. + const int64 target_parallel_task_count = GetTargetParallelTaskCount( + cost_status.ok() ? &cost_analysis : nullptr, instruction.get()); + if (target_parallel_task_count == 1) { + continue; + } + + // Assign feasible dimension partitions (based on actual dimension sizes). + auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) + .Run(target_parallel_task_count); + const int64 total_partition_count = + ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); + if (total_partition_count <= 1) { + // Feasible partition calculation resulting in no partitioning, so skip. + continue; + } + VLOG(2) << "Assigning parallel task count: " << total_partition_count + << " to instruction: " << instruction->name(); + // Map 'instruction' to assigned dimension partitioning. + instruction->set_outer_dimension_partitions(dim_partition_counts); + } + + return changed; +} + +int64 ParallelizationPreparation::GetTargetParallelTaskCount( + const HloCostAnalysis* cost_analysis, HloInstruction* instruction) { + // Default to a simple cost model based on hlo size and typical L2 cache size. + // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an + // error status (likely because HLOs like CustomCall are not yet implemented + // in the HloCostAnalysis). 
+ int64 instruction_cost = shape_size_(instruction->shape()); + int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + if (cost_analysis != nullptr) { + // Calculate the instruction cost in cycles. + // TODO(29630486) Improve on this linear cost model. + // Consider making 'min_cost_per_thread' be a function of the target + // bandwidth limit for instructions with low arithmetic complexity. + instruction_cost = 1 * cost_analysis->flop_count(*instruction) + + 2 * cost_analysis->transcendental_count(*instruction) + + 10 * cost_analysis->bytes_accessed(*instruction); + // Minimum per-thread cost is 100us of work on a 2GHz core. + min_cost_per_thread = 100000; + } + // Return target parallel task count in [1, max_parallelism_]. + return std::min(max_parallelism_, + std::max(1LL, instruction_cost / min_cost_per_thread)); +} + +bool ParallelizationPreparation::OutlineParallelizableInstruction( + HloInstruction* instruction) { + if (instruction->outer_dimension_partitions().empty()) { + return false; + } + // Store dimension partition counts before outlining (which clones + // 'instruction'). + std::vector dim_partition_counts = + instruction->outer_dimension_partitions(); + // Outline 'instruction' in its own sub-computation. + HloModule* module = instruction->parent()->parent(); + auto* call = module->OutlineExpressionFromComputation( + {instruction}, tensorflow::strings::StrCat("pp_", instruction->name()), + module->entry_computation()); + // Map previously assigned 'dim_partition_counts' to cloned root instruction. + VLOG(1) << "Outlining parallelizable" + << " caller: " << call->name() + << " callee: " << call->to_apply()->root_instruction()->name(); + call->to_apply()->root_instruction()->set_outer_dimension_partitions( + dim_partition_counts); + return true; +} + +bool ParallelizationPreparation::CanOutlineWithUser( + HloInstruction* instruction) { + if (instruction->users().size() != 1) { + // Do not outline 'instruction' with multiple users. 
+ return false; + } + if (AssignedParallelTasks(instruction) || + AssignedParallelTasks(*instruction->users().begin())) { + // Do not outline if 'instruction' (or user) were assigned parallel tasks. + return false; + } + return true; +} + +bool ParallelizationPreparation::AssignedParallelTasks( + HloInstruction* instruction) { + return !instruction->outer_dimension_partitions().empty() || + (instruction->opcode() == HloOpcode::kCall && + !instruction->to_apply() + ->root_instruction() + ->outer_dimension_partitions() + .empty()); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h index 62999f5686db2e4db3ace0c5580bd156edbfa994..d53fc461509cad51778dba37922212731236952f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,14 +33,51 @@ namespace cpu { // handle While constructs. class ParallelizationPreparation : public HloPassInterface { public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. 
+ ParallelizationPreparation( + const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size) + : max_parallelism_(max_parallelism), shape_size_(shape_size) {} ~ParallelizationPreparation() override {} + tensorflow::StringPiece name() const override { return "cpu-parallel-prepare"; } - // Run instruction fusion on the given computation. Returns whether the + // Run parallel preparation on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + private: + // Assigns parallel task partitions to conformant instructions in 'module'. + // Returns true on success or error status otherwise. + StatusOr RunParallelTaskAssignment(HloModule* module); + + // Returns the target parallel task count for 'instruction'. + // Utilizes 'cost_analysis' if non-null. + // Otherwise defaults to a simple HLO output size-based cost model. + int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis, + HloInstruction* instruction); + + // Outlines 'instruction' from entry computation, if it had + // been assigned parallel tasks in an earlier pass through the computation. + // Returns true if 'instruction' was successfully outlined, false otherwise. + bool OutlineParallelizableInstruction(HloInstruction* instruction); + + // Returns true if 'instruction' can be outlined into the same sub-computation + // with its single user (parallelizable instructions are not outlined with + // each other). Returns false otherwise. + bool CanOutlineWithUser(HloInstruction* instruction); + + // Returns true if 'instruction' (or the root of the sub-computation that + // 'instruction' calls) has had parallel tasks assigned in earlier pass. + // Returns false otherwise. 
+ bool AssignedParallelTasks(HloInstruction* instruction); + + const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 420f9cebc5b1ded365c20079589ebc79a03b3164..c21b8a9addabb3fb409b2e7420e8abbce62a1190 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -22,9 +22,9 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Instructions.h" #include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/IR/Value.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +44,8 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config) : dot_(dot), transpose_lhs_(transpose_lhs), transpose_rhs_(transpose_rhs), @@ -52,18 +53,20 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, lhs_array_(lhs_array), rhs_array_(rhs_array), executable_run_options_value_(executable_run_options_value), - ir_builder_(ir_builder) {} + ir_builder_(ir_builder), + hlo_module_config_(hlo_module_config) {} /* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const 
llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder) { + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F32 == type || F64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, lhs_array, rhs_array, executable_run_options_value, - ir_builder); + ir_builder, hlo_module_config); return dot_emitter.Emit(); } @@ -233,20 +236,20 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - bool multi_threaded = flags->xla_cpu_multi_thread_eigen; + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; const char* fn_name; switch (type) { case F32: - fn_name = multi_threaded + fn_name = multi_threaded_eigen ? runtime::kEigenMatmulF32SymbolName : runtime::kEigenSingleThreadedMatmulF32SymbolName; float_type = ir_builder_->getFloatTy(); break; case F64: - fn_name = multi_threaded + fn_name = multi_threaded_eigen ? runtime::kEigenMatmulF64SymbolName : runtime::kEigenSingleThreadedMatmulF64SymbolName; float_type = ir_builder_->getDoubleTy(); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 44dfe5f2a91222d99907e31062fb1d8f74aed3ff..b6147163802dde12a8bf7dde91ac8dad45ba1990 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "external/llvm/include/llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/types.h" @@ -39,7 +40,8 @@ class DotOpEmitter { const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder); + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config); private: DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, @@ -47,7 +49,8 @@ class DotOpEmitter { const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config); // Emits the IR to perform the dot operation. 
tensorflow::Status Emit(); @@ -82,6 +85,7 @@ class DotOpEmitter { const llvm_ir::IrArray& rhs_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* ir_builder_; + const HloModuleConfig& hlo_module_config_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.cc b/tensorflow/compiler/xla/service/cpu/infeed_manager.cc index 14c882a06ee9fdfc66f3d6db55146431634dd85e..2ce27d22c761c139255de7a7c4c04412af0b79f7 100644 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/infeed_manager.cc @@ -34,10 +34,12 @@ void InfeedManager::Reset() { enqueued_buffer_.clear(); } -void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) { +void InfeedManager::EnqueueBuffers(const std::vector& buffers) { tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffer_.empty(); - enqueued_buffer_.push_back(buffer); + for (InfeedBuffer* b : buffers) { + enqueued_buffer_.push_back(b); + } if (was_empty) { // This has the potential to suffer from the notified thread // immediately trying and failing to acquire mu_, but seems diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.h b/tensorflow/compiler/xla/service/cpu/infeed_manager.h index 77472746e659b2ddbd9b54a036775ebdd0084fdd..e9659884530dc34c7d131fc0b59920e8d17144ef 100644 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/infeed_manager.h @@ -21,6 +21,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_ #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/mutex.h" @@ -52,11 +53,12 @@ class InfeedManager { // condition is to call Reset when no computation is taking place. void Reset(); - // Adds buffer to the infeed queue. 
buffer->Done will be called when - // the buffer will no longer be accessed by the InfeedManager, - // either as a result of a call to Reset or because the runtime has - // dequeued and used the buffer. - void EnqueueBuffer(InfeedBuffer* buffer); + // Adds a set of buffers to the infeed queue + // atomically. buffer->Done will be called when the buffer will no + // longer be accessed by the InfeedManager, either as a result of a + // call to Reset or because the runtime has dequeued and used the + // buffer. + void EnqueueBuffers(const std::vector& buffers); // Blocks until the infeed queue is non-empty, then returns the // buffer at the head of the queue. Sets the current buffer to be diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc index c65d8216606a1caa561adea5a83c8f1aa2c82906..a59fa35fdbc9ef2628582b30c6d348a281912469 100644 --- a/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc +++ b/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc @@ -57,8 +57,8 @@ TEST_F(InfeedManagerTest, SingleThreadedSequential) { cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); - infeed->EnqueueBuffer(a); - infeed->EnqueueBuffer(b); + infeed->EnqueueBuffers({a}); + infeed->EnqueueBuffers({b}); ProcessNextBuffer(a->length()); ProcessNextBuffer(b->length()); } @@ -69,9 +69,9 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager(); - infeed->EnqueueBuffer(a); + infeed->EnqueueBuffers({a}); ProcessNextBuffer(a->length()); - infeed->EnqueueBuffer(b); + infeed->EnqueueBuffers({b}); ProcessNextBuffer(b->length()); } @@ -92,7 +92,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) { } } TestInfeedBuffer* a = new TestInfeedBuffer(length); - infeed->EnqueueBuffer(a); + infeed->EnqueueBuffers({a}); }); ProcessNextBuffer(length); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc 
b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 2d855d0eb1e9448707b3916d20803cebf2ebabe4..859329e2c1ddca9dbea14c16b67f63d4803b6acd 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -26,11 +25,6 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution) { - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - if (!flags->xla_cpu_use_eigen) { - return false; - } - // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. @@ -82,11 +76,6 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, } // namespace bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); - if (!flags->xla_cpu_use_eigen) { - return false; - } - // For certain types of Dot, we can call Eigen if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 7ad497ff1a27ff083517de6a82a8c4b903800cce..cc27db8993801de8297816f0f38d67ed445dd379 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -33,7 +33,6 @@ limitations under the License. 
#include "external/llvm/include/llvm/IR/Intrinsics.h" #include "external/llvm/include/llvm/IR/LLVMContext.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" @@ -55,6 +54,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -84,6 +84,12 @@ StatusOr IrEmitter::EmitComputation( std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; + num_dynamic_loop_bounds_ = 0; + if (!computation->root_instruction()->outer_dimension_partitions().empty()) { + num_dynamic_loop_bounds_ = + computation->root_instruction()->outer_dimension_partitions().size(); + } + InitializeIrFunction(function_name, is_entry_computation); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -112,7 +118,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name, bool is_entry_computation) { // The function signature is: // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* prof_counters) + // i64* dynamic_loop_bounds, i64* prof_counters) // // retval: points to the returned value. // params: address of an array with pointers to parameters. 
@@ -152,6 +158,10 @@ void IrEmitter::InitializeIrFunction(const string& function_name, // | temp 0 | | temp 1 | | temp N-1 | // \---------/ \---------/ \-----------/ // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // // /---------------------------------------------\ // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | // (elided for aot) \---------------------------------------------/ @@ -164,6 +174,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name, llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); std::vector compute_function_params( {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds_ > 0) { + compute_function_params.push_back(i64_ptr_type); + } if (hlo_to_profile_idx_) { compute_function_params.push_back(i64_ptr_type); } @@ -190,6 +203,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name, (++arg_iter)->setName("run_options"); (++arg_iter)->setName("params"); (++arg_iter)->setName("temps"); + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + } if (hlo_to_profile_idx_) { (++arg_iter)->setName("prof_counters"); } @@ -242,12 +258,12 @@ Status IrEmitter::HandleConstant(HloInstruction* constant, return Status::OK(); } -Status IrEmitter::HandleCopy(HloInstruction* copy, HloInstruction* operand) { +Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); emitted_value_[copy] = copy_value; - return EmitMemcpy(*operand, *copy); + return EmitMemcpy(*(copy->operand(0)), *copy); } else { // Use the elemental emitter for non-tuple shapes. 
return DefaultAction(copy); @@ -358,6 +374,57 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, Status IrEmitter::HandleInfeed(HloInstruction* infeed) { VLOG(2) << "HandleInfeed: " << infeed->ToString(); + const Shape& shape = infeed->shape(); + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(infeed)); + + if (ShapeUtil::IsTuple(shape)) { + TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); + + // For a tuple, we first copy each of the internal elements to + // their corresponding target locations. We then construct the + // tuple outer buffer containing pointers to the internal + // elements. + std::vector tuple_element_addresses; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, + assignment_.GetUniqueSlice(infeed, {i})); + + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(shape, i); + + // Only the outer tuple buffer's target address is obtained from + // EmitTargetAddressForOp to handle the case when Infeed is the + // root instruction. Target addresses for internal elements can + // be obtained from EmitTempBufferPointer. 
+ llvm::Value* tuple_element_address = + EmitTempBufferPointer(buffer, tuple_element_shape); + + TF_RETURN_IF_ERROR(EmitInfeedTransfer(ByteSizeOf(tuple_element_shape), + tuple_element_address)); + + tuple_element_addresses.push_back(tuple_element_address); + } + + llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape), + tuple_element_addresses, &ir_builder_); + } else { + TF_RETURN_IF_ERROR(EmitInfeedTransfer(ByteSizeOf(shape), target_address)); + } + + emitted_value_[infeed] = target_address; + + return Status::OK(); +} + +Status IrEmitter::EmitInfeedTransfer(int64 length, + llvm::Value* target_address) { + if (length > std::numeric_limits::max()) { + return InvalidArgument("infeed buffer length %lld is too large", length); + } + int32 length_32 = static_cast(length); + // The signature of the acquire infeed buffer function is: // // (void*)(int32 length); @@ -384,26 +451,14 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type)); release_func->setCallingConv(llvm::CallingConv::C); - const Shape& shape = infeed->shape(); - int64 length = ByteSizeOf(shape); - if (length > std::numeric_limits::max()) { - return InvalidArgument("infeed buffer length %lld is too large", length); - } - int32 length_32 = static_cast(length); - llvm::Value* acquired_pointer = ir_builder_.CreateCall(acquire_func, {ir_builder_.getInt32(length_32)}); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(infeed)); - ir_builder_.CreateMemCpy(target_address, acquired_pointer, length_32, 1); ir_builder_.CreateCall(release_func, {ir_builder_.getInt32(length_32), acquired_pointer}); - emitted_value_[infeed] = target_address; - return Status::OK(); } @@ -760,7 +815,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, // Dot operation is complicated so we delegate to a helper class. 
TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_)); + lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, + hlo_module_config_)); emitted_value_[dot] = target_address; return Status::OK(); @@ -845,9 +901,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags(); + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); const char* fn_name = - (flags->xla_cpu_multi_thread_eigen + (multi_threaded_eigen ? runtime::kEigenConvF32SymbolName : runtime::kEigenSingleThreadedConvF32SymbolName); llvm::Function* conv_func = llvm::cast( @@ -1039,6 +1096,231 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { "Cross replica sum not implemented on CPU. See b/33011107."); } +// Fills up the free variables in 'index_with_free_var' with values from +// 'filler_index'. The size of free variables must be the same as the +// size of 'filler_index'. +// +// This is often used after dimension reduction, where +// 'index_with_free_var' has one or more dimensions reduced, which serves as +// free variables (represented as nullptr). For example, if we have a 4 +// dimensional input and index for the dimension being reduced is +// 2 (third dimension), we will have an index like [i, j, NULL, k] +// after reduced dimension. +// +// Here we fill up that free variable by 'filler_index', which contains +// the value in the reduced dimension. 
+static llvm_ir::IrArray::Index FillReducedDimensionIndex( + llvm_ir::IrArray::Index index_with_free_var, + llvm_ir::IrArray::Index filler_index) { + llvm_ir::IrArray::Index::const_iterator it = filler_index.begin(); + + for (size_t i = 0; i < index_with_free_var.size(); ++i) { + if (index_with_free_var[i] == nullptr) { + index_with_free_var[i] = *it++; + } + } + CHECK(filler_index.end() == it); + return index_with_free_var; +} + +Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { + // The output of BatchNormTraining is a tuple of three element: + // - An N-dimensional array containing normalized values. + // - A 1 dimensional array containing the mean value for each feature. + // - A 1 dimensional array containing the variance value for each feature. + HloInstruction* operand = batch_norm_training->operands()[0]; + HloInstruction* scale = batch_norm_training->operands()[1]; + HloInstruction* offset = batch_norm_training->operands()[2]; + float epsilon = batch_norm_training->epsilon(); + int64 feature_index = batch_norm_training->feature_index(); + TF_RET_CHECK(ShapeUtil::IsTuple(batch_norm_training->shape()) && + ShapeUtil::TupleElementCount(batch_norm_training->shape()) == 3); + + const Shape& output_shape = + ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 0); + const Shape& feature_shape = + ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 1); + + // Reduce vector of the non-feature dimensions. + std::vector dimensions_to_reduce; + + for (int64 i = 0; i < operand->shape().dimensions_size(); ++i) { + if (i != feature_index) { + dimensions_to_reduce.push_back(i); + } + } + + // Get the second and third allocations in the output tuple, which should be + // used to store the result of mean and variance value calculation. 
+ TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice_mean, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{1})); + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice_var, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{2})); + const int feature_count = output_shape.dimensions(feature_index); + const int size_in_elements = ShapeUtil::ElementsIn(output_shape); + TF_RET_CHECK(ShapeUtil::ElementsIn(operand->shape()) == size_in_elements); + const int elements_per_feature = size_in_elements / feature_count; + + llvm::Value* mean = EmitTempBufferPointer(slice_mean, feature_shape); + llvm_ir::IrArray mean_array(mean, feature_shape); + + llvm::Value* var = EmitTempBufferPointer(slice_var, feature_shape); + llvm_ir::IrArray var_array(var, feature_shape); + + // This loop calculates mean and variance for each feature. + // + // In theory this could be swapped by multi-output fusion. We will evaluate + // this when it's ready. + // + // For variance calculation, we use a simplified formula so we can fuse the + // computation into the same loop to calculate mean: Var=E(X^2) - E(X)^2. + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter( + [this, operand, dimensions_to_reduce, feature_shape, var_array, + elements_per_feature](const llvm_ir::IrArray::Index& index) { + PrimitiveType element_type = operand->shape().element_type(); + // Used to calculate E(X). + llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + "sum_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(element_type)); + + // Used to calculate E(X^2). 
+ llvm::Value* sum_square_address = + llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + "sum_square_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(element_type)); + + ir_builder_.CreateStore( + llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), + sum_address); + + ir_builder_.CreateStore( + llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), + sum_square_address); + + llvm_ir::ForLoopNest loops(&ir_builder_); + + const llvm_ir::IrArray::Index reduced_dims_index = + loops.AddLoopsForShapeOnDimensions( + operand->shape(), dimensions_to_reduce, "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), + &ir_builder_); + + llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray::Index input_index = + FillReducedDimensionIndex(reduced_dims_index, index); + llvm::Value* new_value = + operand_array.EmitReadArrayElement(input_index, &ir_builder_); + + llvm::Value* new_value_square = + ir_builder_.CreateFMul(new_value, new_value); + + llvm::Value* current_sum = ir_builder_.CreateLoad(sum_address); + llvm::Value* current_sum_square = + ir_builder_.CreateLoad(sum_square_address); + // Update sum. + ir_builder_.CreateStore( + ir_builder_.CreateFAdd(current_sum, new_value), sum_address); + + // Update sum square. + ir_builder_.CreateStore( + ir_builder_.CreateFAdd(current_sum_square, new_value_square), + sum_square_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), + &ir_builder_); + + llvm::Value* sum = ir_builder_.CreateLoad(sum_address); + llvm::Value* elements_per_feature_value = llvm::ConstantFP::get( + ir_builder_.getFloatTy(), elements_per_feature); + llvm::Value* mean = + ir_builder_.CreateFDiv(sum, elements_per_feature_value); + llvm::Value* mean_square = ir_builder_.CreateFMul(mean, mean); + llvm::Value* sum_square = + ir_builder_.CreateLoad(sum_square_address); + + // Var=E(X^2) - E(X)^2. 
+ llvm::Value* var = ir_builder_.CreateFSub( + ir_builder_.CreateFDiv(sum_square, elements_per_feature_value), + mean_square); + + var_array.EmitWriteArrayElement(index, var, &ir_builder_); + return mean; + }, + mean_array, &ir_builder_) + .EmitLoop()); + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(batch_norm_training)); + + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); + + llvm::Value* normalized = EmitTempBufferPointer(slice, output_shape); + + llvm_ir::IrArray target_array(normalized, output_shape); + + AddAliasingInformationToIrArray(*batch_norm_training, &target_array); + + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter( + [this, mean_array, var_array, epsilon, operand, dimensions_to_reduce, + feature_index, offset, scale](const llvm_ir::IrArray::Index& index) { + // The following logic normalizes the input value, scales and shifts + // it: + // + // normalized = (input - mean) / sqrt(variance + epsilon) + // result = normalized * scale + offset + + // Current index in the feature dimension. 
+ llvm_ir::IrArray::Index feature_index_value(1, + index[feature_index]); + + llvm::Value* mean = mean_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm::Value* var = var_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + + llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm::Value* input = + operand_array.EmitReadArrayElement(index, &ir_builder_); + + llvm::Value* variance_with_epsilon = ir_builder_.CreateFAdd( + var, llvm::ConstantFP::get(ir_builder_.getFloatTy(), epsilon)); + llvm::Function* func_llvm_sqrt = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {ir_builder_.getFloatTy()}); + llvm::Value* variance_sqrt = + ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon}); + llvm::Value* normalized = ir_builder_.CreateFDiv( + ir_builder_.CreateFSub(input, mean), variance_sqrt); + llvm_ir::IrArray offset_array(GetIrArrayForOp(offset)); + llvm::Value* offset = offset_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm_ir::IrArray scale_array(GetIrArrayForOp(scale)); + llvm::Value* scale = scale_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm::Value* result = ir_builder_.CreateFAdd( + ir_builder_.CreateFMul(normalized, scale), offset); + + return result; + }, + target_array, &ir_builder_) + .EmitLoop()); + + llvm_ir::EmitTuple( + llvm_ir::IrArray(target_address, batch_norm_training->shape()), + {normalized, mean, var}, &ir_builder_); + emitted_value_[batch_norm_training] = target_address; + + return Status::OK(); +} + Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); auto param_number = parameter->parameter_number(); @@ -1283,7 +1565,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *dot, dot->operand(0)->IsRank2Transpose(), dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array, - 
GetExecutableRunOptionsArgument(), &ir_builder_)); + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_)); emitted_value_[fusion] = target_address; return Status::OK(); @@ -1606,13 +1888,24 @@ llvm::Argument* IrEmitter::GetResultArgument() { } llvm::Argument* IrEmitter::GetProfileCountersArgument() { - return hlo_to_profile_idx_ ? GetArg(compute_function_, 4) : nullptr; + const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; + return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr; } llvm::Value* IrEmitter::GetTempBuffersArgument() { return GetArg(compute_function_, 3); } +llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_.CreateLoad( + ir_builder_.CreateGEP(loop_bounds_arg, ir_builder_.getInt64(offset), + llvm_ir::AsStringRef(name))); +} + llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return GetArg(compute_function_, 1); } @@ -1645,11 +1938,14 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( GetTempBuffersArgument(), slice.index(), &ir_builder_); llvm::LoadInst* tempbuf_address_base = ir_builder_.CreateLoad(tempbuf_address_ptr); - // Loading the address of a buffer is invariant of the point at which the - // load is executed in the program because we never reassign buffers. - tempbuf_address_base->setMetadata( - llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); + if (hlo_module_config_.debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // Loading the address of a buffer is invariant of the point at which the + // load is executed in the program because we never reassign buffers. 
+ tempbuf_address_base->setMetadata( + llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); + } llvm_ir::SetTbaaForInstruction(tempbuf_address_base, target_shape, /*is_pointer_to=*/true); AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); @@ -1739,13 +2035,13 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( } StatusOr IrEmitter::EmitTargetAddressForOp( - const HloInstruction* op) { - const Shape& target_shape = op->shape(); - if (op == op->parent()->root_instruction()) { + const HloInstruction* op, const ShapeIndex& shape_index) { + const Shape& target_shape = ShapeUtil::GetSubshape(op->shape(), shape_index); + if (op == op->parent()->root_instruction() && shape_index.empty()) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = GetResultArgument(); - if (!ShapeUtil::HasZeroElements(target_shape)) { + if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); @@ -1773,16 +2069,106 @@ Status IrEmitter::EmitTargetElementLoop( TF_ASSIGN_OR_RETURN(llvm::Value * target_address, EmitTargetAddressForOp(target_op)); VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address); - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*target_op, &target_array); - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) - .EmitLoop()); + if (target_op->IsMultiOutputFusion()) { + // For multiple outputs fusion, we need to emit each operand and the root. 
+ TF_RET_CHECK(num_dynamic_loop_bounds_ == 0); + std::vector output_arrays; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(target_shape); ++i) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(target_op, {i})); + const Shape& element_shape = ShapeUtil::GetSubshape(target_shape, {i}); + llvm::Value* op_target_address = + EmitTempBufferPointer(slice, element_shape); + output_arrays.push_back( + llvm_ir::IrArray(op_target_address, element_shape)); + } + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, output_arrays, &ir_builder_) + .EmitLoop()); + + std::vector tuple_operand_ptrs; + for (int64 i = 0; i < output_arrays.size(); ++i) { + tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); + } + llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, target_shape), + tuple_operand_ptrs, &ir_builder_); + + } else { + llvm_ir::IrArray target_array(target_address, target_shape); + AddAliasingInformationToIrArray(*target_op, &target_array); + + if (num_dynamic_loop_bounds_ > 0 && + target_op == target_op->parent()->root_instruction()) { + // Emit parallel loop for root instruction if dynamic outer-dimension loop + // bounds were specified. + TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( + target_shape, element_generator, &target_array)); + } else { + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) + .EmitLoop()); + } + } + emitted_value_[target_op] = target_address; return Status::OK(); } +Status IrEmitter::EmitParallelTargetElementLoop( + const Shape& target_shape, + const llvm_ir::ElementGenerator& element_generator, + llvm_ir::IrArray* target_array) { + CHECK(!ShapeUtil::IsTuple(target_shape)); + CHECK(!ShapeUtil::IsScalar(target_shape)); + + // Emit code to read dynamic loop bounds from function argument 4. 
+ std::vector dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); + for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i] = GetDynamicLoopBound(i); + } + + llvm_ir::ForLoopNest loop_nest(&ir_builder_); + const int64 num_dims = target_shape.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { + const int64 dimension = target_shape.layout().minor_to_major(i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < num_dynamic_loop_bounds_) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; + llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; + + std::unique_ptr loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. + std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/target_shape.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); + + // Emit loop body. + TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + element_generator(array_index)); + target_array->EmitWriteArrayElement(array_index, target_element, + &ir_builder_); + // Point IR builder at outer loop exit BB. 
+ SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); + + return Status::OK(); +} + Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index ebb7296a075f266870fa179a0791dd6d0f77e29f..ff1ce7218b39e8556cd433878dde26432adf9624 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -96,7 +96,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, @@ -106,6 +106,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* rhs) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* infeed) override; @@ -192,6 +193,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of the computation + // function being emitted by this emitter. 
+ llvm::Value* GetDynamicLoopBound(const int64 offset); + // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. @@ -262,6 +268,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); + // Emit IR to perform a computation for every element in a partition/slice of + // 'target_shape'. The loop bounds for the outer-dimension partitions are + // passed into the compute function as a runtime argument (accessible from + // GetDynamicLoopBound). + Status EmitParallelTargetElementLoop( + const Shape& target_shape, + const llvm_ir::ElementGenerator& element_generator, + llvm_ir::IrArray* target_array); + // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -271,7 +286,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Emit IR to compute the target address of the buffer for the given op. // The returned Value is a pointer to a IR type that represents the op's // element type. - StatusOr EmitTargetAddressForOp(const HloInstruction* op); + StatusOr EmitTargetAddressForOp( + const HloInstruction* op, const ShapeIndex& shape_index = {}); // Structurizes "array_elements" into an MD array that represents "shape". // This is a recursive function, and "dimension_index" indicates the index of @@ -319,6 +335,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; + // The number of root instruction outer dimensions used in parallel loop + // emission (EmitParallelTargetElementLoop). + int64 num_dynamic_loop_bounds_ = 0; + // This struct contains all the state needed to emit instructions for // profiling a computation. 
class ProfilingState { @@ -404,6 +424,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; + // Emit IR to transfer an infeed buffer to the target address. + Status EmitInfeedTransfer(int64 length, llvm::Value* target_address); + const HloModuleConfig& hlo_module_config_; TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index bdddca99c2f50c47ab112eda92ab1509f5448849..598858d4ed96fbefae1ca70518bc8457b3801bcb 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -70,7 +71,7 @@ ParallelCpuExecutable::ParallelCpuExecutable( // Type of the computation function we expect in the JIT. using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - uint64*); + int64*, uint64*); // Given a pointer to an output buffer (following the CPU JIT calling // conventions), mark addresses that are "live". The initial pointer itself is @@ -95,6 +96,232 @@ static void MarkLiveAddressesInOutput( } } +namespace { + +// Executor manages the concurrent execution of 'functions' for instructions +// in 'pending' on 'thread_pool' (storing resulting data in 'results'). 
+class Executor { + public: + Executor(const std::map& functions, + const ServiceExecutableRunOptions* run_options, + std::list* pending, + std::map* results, void** temps_array, + uint64* profile_counters_array, BufferAssignment* assignment) + : functions_(functions), + run_options_(run_options), + pending_(pending), + results_(results), + temps_array_(temps_array), + profile_counters_array_(profile_counters_array), + thread_pool_(CHECK_NOTNULL(run_options_->xla_intra_op_thread_pool())), + assignment_(assignment) {} + + // Executes pending list of instructions on thread pool. + // Returns OK status on success, error status otherwise. + Status Run(); + + private: + // Schedules a parallel invocation of compute function for 'instruction' on + // 'thread_pool_', storing result in 'result_buffer'. + // If 'partition_buffers' is non-null, parallel task will be invoked on + // per-dimension partition [start, limit) values stored in + // 'partition_buffers'. + void Schedule(HloInstruction* instruction, int64* partition_buffers, + void* result_buffer); + + // Returns true if 'instruction' has been assigned parallel tasks (returns + // false otherwise). + bool HasParallelTasks(HloInstruction* instruction); + + // Returns in 'partition_buffers' the partition [size, limit) for each + // dimension. + int64* GetPartitionBuffers( + const std::vector>& partition); + + // Returns array of result buffers for all operands in 'instruction'. + const void** GetOperandBuffers(HloInstruction* instruction); + + // Arguments passed into Executor. + const std::map& functions_; + const ServiceExecutableRunOptions* run_options_; + std::list* pending_; + std::map* results_; + void** temps_array_; + uint64* profile_counters_array_; + tensorflow::thread::ThreadPool* thread_pool_; + BufferAssignment* assignment_; + + // Members used to manage instruction execution. 
+ tensorflow::mutex completion_queue_lock_; + tensorflow::condition_variable completion_queue_cv_; + std::deque completion_queue_; + int64 instructions_in_flight_ = 0; + std::unordered_map tasks_in_flight_; +}; + +Status Executor::Run() { + while (!pending_->empty() || instructions_in_flight_ > 0) { + auto pending_it = pending_->begin(); + while (pending_it != pending_->end()) { + HloInstruction* instruction = *pending_it; + // Skip pending instructions whose operands aren't ready. + if (std::any_of(instruction->operands().begin(), + instruction->operands().end(), + [&](HloInstruction* operand) { + return !ContainsKey(*results_, operand); + })) { + ++pending_it; + continue; + } + + // Get 'result_buffer' reference to result buffer for 'instruction'. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelSlice(instruction)); + void* result_buffer = + static_cast(temps_array_[result_slice.index()]) + + result_slice.offset(); + + if (HasParallelTasks(instruction)) { + // 'instruction' has been assigned parallel task partitions. + CHECK_EQ(HloOpcode::kCall, instruction->opcode()); + HloInstruction* root = instruction->to_apply()->root_instruction(); + + // Create ShapePartitionIterator to iterate through all outer dimension + // partitions of 'instruction'. + ShapePartitionIterator partition_iterator( + root->shape(), root->outer_dimension_partitions()); + + const int64 partition_count = + partition_iterator.GetTotalPartitionCount(); + + // Record total parallel task count for 'instruction' before dispatch. 
+ { + tensorflow::mutex_lock l(completion_queue_lock_); + tasks_in_flight_.insert(std::make_pair(instruction, partition_count)); + VLOG(2) << "Schedule PARALLEL" + << " instruction: " << instruction->name() + << " instruction.callee: " + << instruction->to_apply()->root_instruction()->name() + << " partition_count: " << partition_count; + } + + for (int64 i = 0; i < partition_count; ++i) { + // Get partition [start, limit) for each dimension. + auto partition_buffers = + GetPartitionBuffers(partition_iterator.GetPartition(i)); + Schedule(instruction, partition_buffers, result_buffer); + } + + } else { + // Set tasks in-flight to '1' for sequential instruction execution. + { + tensorflow::mutex_lock l(completion_queue_lock_); + tasks_in_flight_.insert(std::make_pair(instruction, 1)); + VLOG(2) << "Schedule SEQUENTIAL" + << " instruction: " << instruction->name() + << " instruction.callee: " + << instruction->to_apply()->root_instruction()->name(); + } + Schedule(instruction, nullptr, result_buffer); + } + + ++instructions_in_flight_; + pending_it = pending_->erase(pending_it); + } + // Wait for a completed HLO instruction to be present in the queue. We will + // pop it out of the queue and make the result available to its users. 
+ HloInstruction* instruction; + do { + tensorflow::mutex_lock l(completion_queue_lock_); + if (completion_queue_.empty()) { + completion_queue_cv_.wait(l); + } + if (!completion_queue_.empty()) { + instruction = completion_queue_.front(); + completion_queue_.pop_front(); + break; + } + } while (1); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelSlice(instruction)); + void* result_buffer = + static_cast(temps_array_[result_slice.index()]) + + result_slice.offset(); + InsertOrDie(results_, instruction, result_buffer); + --instructions_in_flight_; + } + return Status::OK(); +} + +void Executor::Schedule(HloInstruction* instruction, int64* partition_buffers, + void* result_buffer) { + // The thread pool entry takes ownership of |operand_buffers|. + auto operand_buffers = GetOperandBuffers(instruction); + + auto function = FindOrDie(functions_, instruction); + const auto* exec_run_options = &run_options_->run_options(); + thread_pool_->Schedule([this, instruction, result_buffer, operand_buffers, + partition_buffers, exec_run_options, function]() { + function(result_buffer, exec_run_options, operand_buffers, temps_array_, + partition_buffers, profile_counters_array_); + + delete[] operand_buffers; + delete[] partition_buffers; + // Push the completed HLO instruction on the queue, the main + // thread will pop it off and potentially launch more work which + // uses the result. + // TODO(b/27458679) Consider alternative task scheduling and synchronization + // schemes. For example, we could avoid the overhead associate with the + // condvar here if the thread just dequed the next instruction to execute + // on completion. + { + tensorflow::mutex_lock l(completion_queue_lock_); + // Decrement in-flight task count for this completion. 
+ if (--FindOrDie(tasks_in_flight_, instruction) == 0) { + completion_queue_.push_back(instruction); + completion_queue_cv_.notify_all(); + tasks_in_flight_.erase(instruction); + } + } + }); +} + +int64* Executor::GetPartitionBuffers( + const std::vector>& partition) { + // Return in 'partition_buffers' partition [size, limit) for each dimension. + auto partition_buffers = new int64[partition.size() * 2]; + for (int i = 0; i < partition.size(); ++i) { + partition_buffers[2 * i + 0] = partition[i].first; + partition_buffers[2 * i + 1] = partition[i].first + partition[i].second; + } + return partition_buffers; +} + +bool Executor::HasParallelTasks(HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCall && + !instruction->to_apply() + ->root_instruction() + ->outer_dimension_partitions() + .empty(); +} + +const void** Executor::GetOperandBuffers(HloInstruction* instruction) { + // We cannot use a move-only RAII type like std::unique_ptr because the + // list of operands is allocated on the main thread and transferred to the + // worker via the lambda passed to enqueue_function. In order for the + // lambda to take ownership, we would need to use generalized lambda + // capture which is a feature new to C++14. + // TODO(b/27458679) Avoid dynamic allocations in Executor. 
+ auto operand_buffers = new const void*[instruction->operand_count()]; + std::transform(instruction->operands().begin(), instruction->operands().end(), + operand_buffers, [this](HloInstruction* operand) { + return FindOrDie(*results_, operand); + }); + return operand_buffers; +} + +} // namespace + Status ParallelCpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, std::vector* buffers) { @@ -210,88 +437,16 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } } - void** temps_array = buffer_pointers.data(); - uint64* profile_counters_array = profile_counters.data(); - auto* thread_pool = CHECK_NOTNULL(run_options->xla_intra_op_thread_pool()); - tensorflow::mutex completion_queue_lock; - tensorflow::condition_variable completion_queue_cv; - std::deque completion_queue; - int64 instructions_in_flight = 0; - while (!pending.empty() || instructions_in_flight > 0) { - auto pending_it = pending.begin(); - while (pending_it != pending.end()) { - HloInstruction* instruction = *pending_it; - // Skip pending instructions whose operands aren't ready. - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), - [&](HloInstruction* operand) { - return !ContainsKey(results, operand); - })) { - ++pending_it; - continue; - } + // TODO(b/27458679) Manage scheduling based on in-flight concurrency limits. + // For example, if we expect a library conv/matmul call to run at max + // concurrency, we should not dispatch runnable instructions until the + // library call is finished (to avoid expensive cache invalidation). 
+ Executor executor(functions, run_options, &pending, &results, + buffer_pointers.data(), profile_counters.data(), + assignment_.get()); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array[result_slice.index()]) + - result_slice.offset(); - // We cannot use a move-only RAII type like std::unique_ptr because the - // list of operands is allocated on the main thread and transferred to the - // worker via the lambda passed to enqueue_function. In order for the - // lambda to take ownership, we would need to use generalized lambda - // capture which is a feature new to C++14. - auto operand_buffers = new const void*[instruction->operand_count()]; - std::transform(instruction->operands().begin(), - instruction->operands().end(), operand_buffers, - [&results](HloInstruction* operand) { - return FindOrDie(results, operand); - }); - auto function = FindOrDie(functions, instruction); - // The thread pool entry takes ownership of |operand_buffers|. - const auto* exec_run_options = &run_options->run_options(); - thread_pool->Schedule([instruction, &completion_queue, - &completion_queue_lock, &completion_queue_cv, - result_buffer, exec_run_options, operand_buffers, - temps_array, profile_counters_array, function] { - function(result_buffer, exec_run_options, operand_buffers, temps_array, - profile_counters_array); - delete[] operand_buffers; - // Push the completed HLO instruction on the queue, the main thread - // will pop it off and potentially launch more work which uses the - // result. - { - tensorflow::mutex_lock l(completion_queue_lock); - completion_queue.push_back(instruction); - completion_queue_cv.notify_all(); - } - }); + TF_RETURN_IF_ERROR(executor.Run()); - ++instructions_in_flight; - pending_it = pending.erase(pending_it); - } - // Wait for a completed HLO instruction to be present in the queue. 
We will - // pop it out of the queue and make the result available to its users. - HloInstruction* instruction; - do { - tensorflow::mutex_lock l(completion_queue_lock); - if (completion_queue.empty()) { - completion_queue_cv.wait(l); - } - if (!completion_queue.empty()) { - instruction = completion_queue.front(); - completion_queue.pop_front(); - break; - } - } while (1); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array[result_slice.index()]) + - result_slice.offset(); - InsertOrDie(&results, instruction, result_buffer); - --instructions_in_flight; - } uint64 end_micros = tensorflow::Env::Default()->NowMicros(); { diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 8f1ce82d49a1c7cabfb62bf30e69faedc0318138..b3f4609d465efb4df8921abb684bafd263fe040f 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -38,13 +38,12 @@ int main(int argc, char** argv) { // Transfer parameters. 
std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + xla::Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + std::unique_ptr param1_literal = xla::Literal::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = client->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -69,7 +68,7 @@ int main(int argc, char** argv) { LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", profile.compute_time_ns()); - LOG(INFO) << xla::LiteralUtil::ToString(*actual); + LOG(INFO) << actual->ToString(); return 0; } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc new file mode 100644 index 0000000000000000000000000000000000000000..61b408b8c24dded134218110d4e219c31f1685a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" + +namespace xla { +namespace cpu { + +std::vector ShapePartitionAssigner::Run(int64 target_partition_count) { + // Gather outer-most dims where dim_size >= 'target_partition_count'. + // Note: always leave inner-dim static for vectorization/optimizations. + std::vector outer_dims; + int64 outer_dim_size = 1; + // TODO(b/27458679) Consider reserving enough minor dimensions (based on + // target vector register width) to enable vector instructions. + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 1; --i) { + const int64 dimension = shape_.layout().minor_to_major(i); + outer_dims.push_back(dimension); + outer_dim_size *= shape_.dimensions(dimension); + if (outer_dim_size >= target_partition_count) { + break; + } + } + + // Clip target partition count if outer dim size is insufficient to cover. + target_partition_count = std::min(outer_dim_size, target_partition_count); + + // Calculate the target number of partitions per-dimension, by factoring + // 'target_partition_count' into 'num_outer_dims' equal terms. + // EX: + // *) target_partition_count = 16 + // *) out_dim_count = 2 + // *) target_dim_partition_count = 16 ^ (1.0 / 2) == 4 + const int64 target_dim_partition_count = std::pow( + static_cast(target_partition_count), 1.0 / outer_dims.size()); + + // Assign feasible dimension partitions based on 'target_dim_partition_count' + // and actual dimension sizes from 'shape_'. + std::vector dimension_partition_counts(outer_dims.size()); + for (int64 i = 0; i < outer_dims.size(); ++i) { + dimension_partition_counts[i] = + std::min(static_cast(shape_.dimensions(outer_dims[i])), + target_dim_partition_count); + } + + // Check if total partition count is below 'target_partition_count'. + // This can occur if some dimensions in 'shape_' are below the + // 'target_dim_partition_count' threshold. 
+ if (GetTotalPartitionCount(dimension_partition_counts) < + target_partition_count) { + // Assign additional partitions (greedily to outer dimensions), if doing + // so would keep the total number of partitions <= 'target_partition_count', + // using one pass over 'dimension_partition_counts'. + for (int64 i = 0; i < dimension_partition_counts.size(); ++i) { + const int64 current_dim_partition_count = dimension_partition_counts[i]; + const int64 other_dims_partition_count = + GetTotalPartitionCount(dimension_partition_counts) / + current_dim_partition_count; + // Constraint: (current + additional) * other <= target + // Calculate: additional = target / other - current + int64 additional_partition_count = + target_partition_count / other_dims_partition_count - + current_dim_partition_count; + // Clip 'additional_partition_count' by current dimension size. + additional_partition_count = std::min( + shape_.dimensions(outer_dims[i]) - dimension_partition_counts[i], + additional_partition_count); + if (additional_partition_count > 0) { + dimension_partition_counts[i] += additional_partition_count; + } + } + } + + return dimension_partition_counts; +} + +int64 ShapePartitionAssigner::GetTotalPartitionCount( + const std::vector& dimension_partition_counts) { + int64 total_partition_count = 1; + for (int64 dim_partition_count : dimension_partition_counts) { + total_partition_count *= dim_partition_count; + } + return total_partition_count; +} + +ShapePartitionIterator::ShapePartitionIterator( + const Shape& shape, const std::vector& dimension_partition_counts) + : shape_(shape), + dimension_partition_counts_(dimension_partition_counts), + dimensions_(dimension_partition_counts_.size()), + dimension_partition_sizes_(dimension_partition_counts_.size()), + dimension_partition_strides_(dimension_partition_counts_.size()) { + // Store partitioned outer dimensions from 'shape_'. 
+ for (int i = 0; i < dimensions_.size(); ++i) { + dimensions_[i] = shape_.layout().minor_to_major( + shape_.layout().minor_to_major_size() - 1 - i); + } + + // Calculate partition size for each dimension (note that the size of + // the last partition in each dimension may be different if the dimension + // size is not a multiple of partition size). + for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { + const int64 dim_size = shape_.dimensions(dimensions_[i]); + dimension_partition_sizes_[i] = + std::max(1LL, dim_size / dimension_partition_counts_[i]); + } + + // Calculate the partition strides for each dimension. + dimension_partition_strides_[dimension_partition_strides_.size() - 1] = 1; + for (int i = dimension_partition_strides_.size() - 2; i >= 0; --i) { + dimension_partition_strides_[i] = dimension_partition_strides_[i + 1] * + dimension_partition_counts_[i + 1]; + } +} + +std::vector> ShapePartitionIterator::GetPartition( + int64 index) const { + // Calculate and return the partition for 'index'. + // Returns for each dimension: (partition_start, partition_size). + std::vector> partition(dimensions_.size()); + for (int64 i = 0; i < partition.size(); ++i) { + // Calculate the index for dimension 'i'. + const int64 partition_index = index / dimension_partition_strides_[i]; + // Calculate dimension partition start at 'partition_index'. + partition[i].first = partition_index * dimension_partition_sizes_[i]; + // Calculate dimension partition size (note that the last partition size + // may be adjusted if dimension size is not a multiple of partition size). + if (partition_index == dimension_partition_counts_[i] - 1) { + // Last partition in this dimension. + partition[i].second = + shape_.dimensions(dimensions_[i]) - partition[i].first; + } else { + partition[i].second = dimension_partition_sizes_[i]; + } + CHECK_GT(partition[i].second, 0); + // Update index to remove conribution from current dimension. 
+ index -= partition_index * dimension_partition_strides_[i]; + } + return partition; +} + +int64 ShapePartitionIterator::GetTotalPartitionCount() const { + return ShapePartitionAssigner::GetTotalPartitionCount( + dimension_partition_counts_); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h new file mode 100644 index 0000000000000000000000000000000000000000..7a2d00421cfdc8e41ec48698a16665621de16bda --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ + +#include + +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { +namespace cpu { + +// ShapePartitionAssigner partitions the most-major dimensions of 'shape' such +// that the total partition count <= 'target_partition_count'. +// +// Example 1: +// +// Let 'shape' = [8, 16, 32] and 'target_partition_count' = 6. +// +// Because the most-major dimension size is <= 'target_partition_count', we +// can generate our target number of partitions by partition the most-major +// dimensions. 
+// +// This will result in the following partitions of the most-major dimension: +// +// [0, 1), [1, 2), [2, 3), [3, 4), [4, 5) [5, 8) +// +// Note that the last partition has residule because the dimension size is +// not a multiple of the partition count. +// +// +// Example 2: +// +// Let 'shape' = [8, 16, 32] and 'target_partition_count' = 16. +// +// Because the most-major dimension only has size 8, we must also partition +// the next most-major dimension to generate the target of 16 partitions. +// We factor 'target_partition_count' by the number of most-major dimensions +// we need to partition, to get a per-dimension target partition count: +// +// target_dimension_partition_count = 16 ^ (1 / 2) == 4 +// +// This will result in the following partitions of the most-major dimension: +// +// [0, 2), [2, 4), [4, 6), [6, 8) +// +// This will result in the following partitions of the second most-major +// dimension: +// +// [0, 4), [4, 8), [8, 12), [12, 16) +// +class ShapePartitionAssigner { + public: + ShapePartitionAssigner(const Shape& shape) : shape_(shape) {} + + // Returns dimension partition counts (starting at outer-most dimension). + std::vector Run(int64 target_partition_count); + + // Returns the total partition count based on 'dimension_partition_counts'. + static int64 GetTotalPartitionCount( + const std::vector& dimension_partition_counts); + + private: + const Shape& shape_; +}; + +// ShapePartitionIterator iterates through outer-dimension partitions of +// 'shape' as specified by 'dimension_partition_counts'. +class ShapePartitionIterator { + public: + ShapePartitionIterator(const Shape& shape, + const std::vector& dimension_partition_counts); + + // Returns a partition [start, size] for each dimension. + // Partitions are listed starting from outer-most dimension first. 
+ std::vector> GetPartition(int64 index) const; + + int64 GetTotalPartitionCount() const; + + private: + const Shape& shape_; + const std::vector dimension_partition_counts_; + + std::vector dimensions_; + std::vector dimension_partition_sizes_; + std::vector dimension_partition_strides_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6cc6e3fe85b2464c3049e277db4535b733909d41 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" + +#include +#include + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace cpu { +namespace { + +class ShapePartitionAssignerTest : public HloTestBase { + protected: + typedef std::vector Vec; + + void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) { + ShapePartitionAssigner assigner(shape); + // Check all partitions of outer dimension. + for (int64 i = 1; i <= expected_max_partition_count; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), + assigner.Run(/*target_partition_count=*/i))); + } + // Check target_partition_count > outer dimension size. + EXPECT_TRUE(ContainersEqual( + Vec({expected_max_partition_count}), + assigner.Run( + /*target_partition_count=*/expected_max_partition_count + 1))); + } +}; + +TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1); +} + +TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1); +} + +TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5); +} + +TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3); +} + +TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); + ShapePartitionAssigner assigner(shape); + + for (int64 i = 1; i <= 5; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( + /*target_partition_count=*/i))); + } + + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), 
assigner.Run(/*target_partition_count=*/7))); + EXPECT_TRUE( + ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/10))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/11))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + assigner.Run(/*target_partition_count=*/12))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + assigner.Run(/*target_partition_count=*/13))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + assigner.Run(/*target_partition_count=*/14))); + EXPECT_TRUE(ContainersEqual(Vec({5, 3}), + assigner.Run(/*target_partition_count=*/15))); + EXPECT_TRUE(ContainersEqual(Vec({5, 3}), + assigner.Run(/*target_partition_count=*/16))); +} + +TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}); + ShapePartitionAssigner assigner(shape); + + for (int64 i = 1; i <= 3; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( + /*target_partition_count=*/i))); + } + + EXPECT_TRUE( + ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4))); + EXPECT_TRUE( + ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/10))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/11))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + 
assigner.Run(/*target_partition_count=*/12))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + assigner.Run(/*target_partition_count=*/13))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + assigner.Run(/*target_partition_count=*/14))); + EXPECT_TRUE(ContainersEqual(Vec({3, 5}), + assigner.Run(/*target_partition_count=*/15))); + EXPECT_TRUE(ContainersEqual(Vec({3, 5}), + assigner.Run(/*target_partition_count=*/16))); +} + +class ShapePartitionIteratorTest : public HloTestBase { + protected: + typedef std::vector> Partition; +}; + +TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}); + + { + ShapePartitionIterator iterator(shape, {1}); + EXPECT_EQ(1, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0))); + } + + { + ShapePartitionIterator iterator(shape, {2}); + EXPECT_EQ(2, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0))); + EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1))); + } + + { + ShapePartitionIterator iterator(shape, {3}); + EXPECT_EQ(3, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0))); + EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1))); + EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2))); + } +} + +TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); + + { + ShapePartitionIterator iterator(shape, {1, 1}); + EXPECT_EQ(1, iterator.GetTotalPartitionCount()); + EXPECT_TRUE( + ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); + } + + { + ShapePartitionIterator iterator(shape, {2, 2}); + EXPECT_EQ(4, iterator.GetTotalPartitionCount()); + EXPECT_TRUE( + ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); + 
EXPECT_TRUE( + ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); + EXPECT_TRUE( + ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); + EXPECT_TRUE( + ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); + } +} + +class RandomShapePartitionIteratorTest : public HloTestBase { + protected: + typedef std::vector> Partition; + RandomShapePartitionIteratorTest() + : generator_(rd_()), distribution_(1, 10) {} + + std::vector RandR4Dims() { return {Rand(), Rand(), Rand(), Rand()}; } + + int64 Rand() { return distribution_(generator_); } + + std::random_device rd_; + std::mt19937 generator_; + std::uniform_int_distribution distribution_; +}; + +TEST_F(RandomShapePartitionIteratorTest, RandomShapeAndPartitions) { + // Choose random dimensions for R4 shape. + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, RandR4Dims(), {3, 2, 1, 0}); + // Choose random number of outer dimensions to partition. + const int num_outer_dims_to_partiton = 1 + (Rand() % 3); + // Choose random outer dimension partiton counts. + std::vector dim_sizes(num_outer_dims_to_partiton); + std::vector dim_partition_counts(num_outer_dims_to_partiton); + int64 total_dim_size = 1; + for (int i = 0; i < num_outer_dims_to_partiton; ++i) { + const int64 dimension = shape.layout().minor_to_major( + shape.layout().minor_to_major_size() - 1 - i); + dim_sizes[i] = shape.dimensions(dimension); + total_dim_size *= dim_sizes[i]; + // Choose dimension partition count in [1, dim_size] + const int64 dim_partition_count = 1 + Rand() % dim_sizes[i]; + dim_partition_counts[i] = dim_partition_count; + } + // Iterate through all partition: for each partition record covered + // index ranges by dimension. 
+ std::vector> ranges(num_outer_dims_to_partiton); + ShapePartitionIterator partition_iterator(shape, dim_partition_counts); + const int64 partition_count = partition_iterator.GetTotalPartitionCount(); + for (int64 i = 0; i < partition_count; ++i) { + const auto& dim_partition = partition_iterator.GetPartition(i); + for (int dim = 0; dim < dim_partition.size(); ++dim) { + ranges[dim].insert( + std::make_pair(dim_partition[dim].first, + dim_partition[dim].first + dim_partition[dim].second)); + } + } + // Check that partitions cover entire dimension size range (for each + // partitioned dimension). + for (int i = 0; i < ranges.size(); ++i) { + int64 expected_index = 0; + for (auto& r : ranges[i]) { + EXPECT_EQ(expected_index, r.first); + expected_index = r.second; + } + EXPECT_EQ(expected_index, dim_sizes[i]); + } +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 7c74912a7ab9c388c9911fe8194f268623f0abd1..8f567b4f8c9dc2bcb87616045669ef845092ca1e 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -21,11 +21,12 @@ limitations under the License. 
#include #include +#include "external/llvm/include/llvm/ExecutionEngine/ExecutionEngine.h" +#include "external/llvm/include/llvm/ExecutionEngine/SectionMemoryManager.h" #include "external/llvm/include/llvm/IR/Mangler.h" #include "external/llvm/include/llvm/Support/CodeGen.h" #include "external/llvm/include/llvm/Support/Host.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" @@ -141,7 +142,9 @@ CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { } // namespace SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options, - llvm::CodeGenOpt::Level opt_level) + llvm::CodeGenOpt::Level opt_level, + OptimizationCallback pre_optimization_callback, + OptimizationCallback post_optimization_callback) : target_machine_( CHECK_NOTNULL(llvm::EngineBuilder() .setTargetOptions(target_options) @@ -152,30 +155,30 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), + object_layer_( + [] { return std::make_shared(); }), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, - opt_level, GetAvailableIntrinsics())) { + opt_level, GetAvailableIntrinsics(), + std::move(pre_optimization_callback), + std::move(post_optimization_callback))) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( std::unique_ptr module) { - // The Orc API adds a whole iterable "set" of modules, so we wrap the module - // in a vector. 
- std::vector> module_set; - module_set.push_back(std::move(module)); - auto handle = compile_layer_.addModuleSet( - std::move(module_set), MakeUnique(), - MakeUnique()); + auto handle = + compile_layer_.addModule(std::move(module), MakeUnique()); module_handles_.push_back(handle); return handle; } void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::ModuleHandleT handle) { module_handles_.erase( - std::remove(module_handles_.begin(), module_handles_.end(), handle)); - compile_layer_.removeModuleSet(handle); + std::remove(module_handles_.begin(), module_handles_.end(), handle), + module_handles_.end()); + compile_layer_.removeModule(handle); } llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string &name) { diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 4d8653484a037a345321dbe11c384f650e0142d0..e28717bcd1dfa64bc39147531484071669eae9d9 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -25,6 +25,7 @@ limitations under the License. #include "external/llvm/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" #include "external/llvm/include/llvm/IR/Module.h" #include "external/llvm/include/llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/types.h" @@ -41,9 +42,13 @@ namespace cpu { // it's added to the JIT. 
class SimpleOrcJIT { public: - using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer<>; - using CompileLayerT = llvm::orc::IRCompileLayer; - using ModuleHandleT = CompileLayerT::ModuleSetHandleT; + using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; + using CompileFtor = + std::function( + llvm::Module&)>; + using CompileLayerT = llvm::orc::IRCompileLayer; + using ModuleHandleT = CompileLayerT::ModuleHandleT; + using OptimizationCallback = CompilerFunctor::OptimizationCallback; // Create a new JIT, targeting the host architecture. // The |target_options| parameter allows customization of certain code @@ -51,8 +56,14 @@ class SimpleOrcJIT { // can be reassociated, etc.). // The |opt_level| parameter controls the optimization level of the code // generator. + // The |pre_optimization_callback| is invoked on the module before any IR + // level optimizations are applied. + // The |post_optimization_callback| is invoked on the module after all IR + // level optimizations are applied. SimpleOrcJIT(const llvm::TargetOptions& target_options, - llvm::CodeGenOpt::Level opt_level); + llvm::CodeGenOpt::Level opt_level, + OptimizationCallback pre_optimization_callback, + OptimizationCallback post_optimization_callback); // Data layout this JIT was created with. const llvm::DataLayout& data_layout() const { return data_layout_; } diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc index 2d9d9c7de62a34e4d18ef1d7f5552a85ad1c49cb..262ba83f3dc688c2edd1ca424e2b855153a08175 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -29,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -66,32 +66,85 @@ CpuTransferManager::CpuTransferManager() Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { const Shape& shape = literal.shape(); - VLOG(2) << "transferring literal shape to infeed: " + VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - // TODO(b/31381668) handle tuples. - if (ShapeUtil::IsTuple(shape)) { - return Unimplemented("Infeed with a tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + if (!ShapeUtil::IsTuple(shape)) { + int64 size = GetByteSizeRequirement(shape); + return TransferBufferToInfeed(executor, size, literal.InternalData()); } + if (ShapeUtil::IsNestedTuple(shape)) { + return Unimplemented( + "Infeed with a nested tuple shape is not supported: %s", + ShapeUtil::HumanString(literal.shape()).c_str()); + } + + // For a tuple, we transfer each of its elements to the device and + // enqueue the resulting destination device addresses with the + // infeed manager. 
+ std::vector buffers; + buffers.reserve(literal.tuple_literals_size()); + auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { + for (cpu::runtime::InfeedBuffer* b : buffers) { + b->Done(); + } + }); + + for (const auto& tuple_element : literal.tuple_literals()) { + const Shape& tuple_element_shape = tuple_element.shape(); + int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); + TF_ASSIGN_OR_RETURN( + cpu::runtime::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, tuple_element_size, + tuple_element.InternalData())); + buffers.push_back(buffer); + } + + cpu::runtime::InfeedManager* infeed_manager = + cpu::runtime::GetInfeedManager(); + infeed_manager->EnqueueBuffers(buffers); + + cleanup.release(); + return Status::OK(); +} + +Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, + const void* source) { + TF_ASSIGN_OR_RETURN(cpu::runtime::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, size, source)); + cpu::runtime::InfeedManager* infeed_manager = cpu::runtime::GetInfeedManager(); + infeed_manager->EnqueueBuffers({buffer}); - int64 size = GetByteSizeRequirement(shape); + return Status::OK(); +} + +StatusOr +CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, + int64 size, + const void* source) { if (size > std::numeric_limits::max()) { - return Unimplemented("Infeed shape is too large: %s needs %lld bytes", - ShapeUtil::HumanString(literal.shape()).c_str(), size); + return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + } + + if (size == 0) { + return InvalidArgument("Infeed shape needs 0 bytes"); } + int32 size_32 = static_cast(size); CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32); - TF_RETURN_IF_ERROR(TransferBufferToDevice( - executor, /*size=*/size, /*source=*/LiteralUtil::InternalData(literal), - queued_buffer->device_memory())); - - infeed_manager->EnqueueBuffer(queued_buffer); + Status s = + 
TransferBufferToDevice(executor, /*size=*/size, + /*source=*/source, queued_buffer->device_memory()); - return Status::OK(); + if (!s.ok()) { + queued_buffer->Done(); + return s; + } + return queued_buffer; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu_transfer_manager.h index 727462252d7291959fd09c05c87e36411eb3ddab..96ffb94d7127620ffe5c73441d11f6e28952b22f 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" @@ -37,8 +38,16 @@ class CpuTransferManager : public GenericTransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; private: + // Transfers infeed data to device. InfeedBuffer->Done() must be + // called to clean up the memory allocated for InfeedBuffer. + StatusOr TransferBufferToInfeedInternal( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source); + TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); }; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 5b296861006923f438df1ad4fb5898f82f11b9e0..0f7ab111170a3152cbe86c1a4fa8d592a14d6241 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -24,51 +24,29 @@ limitations under the License. 
namespace xla { Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* operand) { + HloOpcode opcode) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", HloOpcodeString(opcode).c_str()); } Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { + HloOpcode opcode) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", HloOpcodeString(opcode).c_str()); } void DfsHloVisitor::SetVisiting(const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visiting: "; - CHECK(NotVisited(instruction)); + DCHECK(NotVisited(instruction)); visit_state_[&instruction] = VisitState::kVisiting; } void DfsHloVisitor::SetVisited(const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visited: "; - CHECK(NotVisited(instruction) || IsVisiting(instruction)); + DCHECK(NotVisited(instruction) || IsVisiting(instruction)); visit_state_[&instruction] = VisitState::kVisited; } -bool DfsHloVisitor::IsVisiting(const HloInstruction& instruction) { - if (visit_state_.count(&instruction) == 0) { - return false; - } - return visit_state_[&instruction] == VisitState::kVisiting; -} - -bool DfsHloVisitor::DidVisit(const HloInstruction& instruction) { - if (visit_state_.count(&instruction) == 0) { - return false; - } - return visit_state_[&instruction] == VisitState::kVisited; -} - -bool DfsHloVisitor::NotVisited(const HloInstruction& instruction) { - return visit_state_.count(&instruction) == 0 || - visit_state_[&instruction] == VisitState::kNotVisited; -} - Status DfsHloVisitor::Preprocess(HloInstruction* hlo) { return Status::OK(); } Status DfsHloVisitor::Postprocess(HloInstruction* visited) { diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 
78a398f8efa870fcfbda78a769b3f6878a8a429b..fcc4f85f01c42f34e5f495089eee0bd5671a4c72 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -65,43 +65,37 @@ class DfsHloVisitor { // These routines are self-descriptive, see class comment for usage // information. - virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand); - virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs); + virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode); + virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode); virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) = 0; virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) = 0; - virtual Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(maximum, HloOpcode::kMaximum, lhs, rhs); + virtual Status HandleMaximum(HloInstruction* maximum) { + return HandleElementwiseBinary(maximum, HloOpcode::kMaximum); } - virtual Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(minimum, HloOpcode::kMinimum, lhs, rhs); + virtual Status HandleMinimum(HloInstruction* minimum) { + return HandleElementwiseBinary(minimum, HloOpcode::kMinimum); } virtual Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) = 0; - virtual Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - return HandleElementwiseUnary(convert, HloOpcode::kConvert, operand); + virtual Status HandleConvert(HloInstruction* convert) { + return HandleElementwiseUnary(convert, HloOpcode::kConvert); } - virtual Status 
HandleCopy(HloInstruction* copy, HloInstruction* operand) { - return HandleElementwiseUnary(copy, HloOpcode::kCopy, operand); + virtual Status HandleCopy(HloInstruction* copy) { + return HandleElementwiseUnary(copy, HloOpcode::kCopy); } virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(multiply, HloOpcode::kMultiply, lhs, rhs); + return HandleElementwiseBinary(multiply, HloOpcode::kMultiply); } virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) = 0; virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(power, HloOpcode::kPower, lhs, rhs); + return HandleElementwiseBinary(power, HloOpcode::kPower); } virtual Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, @@ -109,64 +103,70 @@ class DfsHloVisitor { virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0; virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(compare, opcode, lhs, rhs); + return HandleElementwiseBinary(compare, opcode); } virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(add, HloOpcode::kAdd, lhs, rhs); + return HandleElementwiseBinary(add, HloOpcode::kAdd); } virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(divide, HloOpcode::kDivide, lhs, rhs); + return HandleElementwiseBinary(divide, HloOpcode::kDivide); } virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(remainder, HloOpcode::kRemainder, lhs, rhs); + return HandleElementwiseBinary(remainder, HloOpcode::kRemainder); } virtual Status HandleSubtract(HloInstruction* subtract, 
HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(subtract, HloOpcode::kSubtract, lhs, rhs); + return HandleElementwiseBinary(subtract, HloOpcode::kSubtract); } virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { - return HandleElementwiseUnary(abs, HloOpcode::kAbs, operand); + return HandleElementwiseUnary(abs, HloOpcode::kAbs); } virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) { - return HandleElementwiseUnary(sign, HloOpcode::kSign, operand); + return HandleElementwiseUnary(sign, HloOpcode::kSign); } virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) { - return HandleElementwiseUnary(negate, HloOpcode::kNegate, operand); + return HandleElementwiseUnary(negate, HloOpcode::kNegate); } virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) { - return HandleElementwiseUnary(exp, HloOpcode::kExp, operand); + return HandleElementwiseUnary(exp, HloOpcode::kExp); } virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) { - return HandleElementwiseUnary(floor, HloOpcode::kFloor, operand); + return HandleElementwiseUnary(floor, HloOpcode::kFloor); } virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) { - return HandleElementwiseUnary(ceil, HloOpcode::kCeil, operand); + return HandleElementwiseUnary(ceil, HloOpcode::kCeil); } virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) { - return HandleElementwiseUnary(log, HloOpcode::kLog, operand); + return HandleElementwiseUnary(log, HloOpcode::kLog); + } + virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) { + return HandleElementwiseUnary(cos, HloOpcode::kCos); } virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { - return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand); + return HandleElementwiseUnary(tanh, HloOpcode::kTanh); } virtual Status HandleIsFinite(HloInstruction* is_finite, HloInstruction* 
operand) { - return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite, operand); + return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite); } virtual Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs, - rhs); + return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd); } virtual Status HandleLogicalNot(HloInstruction* logical_not, HloInstruction* operand) { - return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot, operand); + return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot); } virtual Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr, lhs, rhs); + return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr); + } + virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { + return HandleElementwiseUnary(reduce_precision, + HloOpcode::kReducePrecision); } virtual Status HandleInfeed(HloInstruction* infeed) = 0; @@ -225,6 +225,8 @@ class DfsHloVisitor { virtual Status HandleRecv(HloInstruction* recv) = 0; + virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". virtual Status FinishVisit(HloInstruction* root) = 0; @@ -237,6 +239,14 @@ class DfsHloVisitor { kVisited, }; + VisitState GetVisitState(const HloInstruction& instruction) { + auto it = visit_state_.find(&instruction); + if (it == visit_state_.end()) { + return kNotVisited; + } + return it->second; + } + // Sets the visitation state of the given instruction as kVisiting. // // Precondition: current state must be kNotVisited. @@ -248,13 +258,19 @@ class DfsHloVisitor { void SetVisited(const HloInstruction& instruction); // Returns whether the state of the given instruction is kVisiting. 
- bool IsVisiting(const HloInstruction& instruction); + bool IsVisiting(const HloInstruction& instruction) { + return GetVisitState(instruction) == kVisiting; + } // Returns whether the state of the given instruction is kVisited. - bool DidVisit(const HloInstruction& instruction); + bool DidVisit(const HloInstruction& instruction) { + return GetVisitState(instruction) == kVisited; + } // Returns whether the state of the given instruction is kNotVisited. - bool NotVisited(const HloInstruction& instruction); + bool NotVisited(const HloInstruction& instruction) { + return GetVisitState(instruction) == kNotVisited; + } // This method should be overridden by subclasses that wish to run some // operation on an op before its Handle* visitor method is called. @@ -279,7 +295,7 @@ class DfsHloVisitor { private: // Tracks the visitation state of each instruction. Any instructions that are - // not found from the map are considered as VisitState::kNotVisited. + // not found in the map are considered as VisitState::kNotVisited. tensorflow::gtl::FlatMap visit_state_; TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitor); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 6557c3aa8e6b8356887432c6dd91d326603fc1e7..2970ba8cc41eaa1f928d69f1b70051591d5efd5d 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -41,15 +41,19 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { // Default action performed on HloInstruction. 
virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0; - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand) override { + Status HandleElementwiseUnary(HloInstruction* hlo, + HloOpcode opcode) override { return DefaultAction(hlo); } - Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode) override { return DefaultAction(hlo); } + + Status HandleBatchNormTraining(HloInstruction* hlo) override { + return DefaultAction(hlo); + } + Status HandleClamp(HloInstruction* clamp, HloInstruction* /*min*/, HloInstruction* /*arg*/, HloInstruction* /*max*/) override { @@ -60,12 +64,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { tensorflow::gtl::ArraySlice /*operands*/) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstruction* convert, - HloInstruction* /*operand*/) override { + Status HandleConvert(HloInstruction* convert) override { return DefaultAction(convert); } - Status HandleCopy(HloInstruction* copy, - HloInstruction* /*operand*/) override { + Status HandleCopy(HloInstruction* copy) override { return DefaultAction(copy); } Status HandleSelect(HloInstruction* select, HloInstruction* /*pred*/, diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index be4aadb6522b8d6ad9d6425df56c1746c3849f11..9c4380b39d5ba449c586b302783c10a221decf7a 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -63,7 +63,7 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); - CHECK(primitive_util::IsIntegralType(from_type)); + 
CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED); if (from_type == to_type) { return operand_value; } @@ -78,7 +78,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); } - if (primitive_util::IsUnsignedIntegralType(from_type)) { + if (primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED) { return ir_builder_->CreateUIToFP( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); @@ -172,6 +173,10 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, {operand_value->getType()}, ir_builder_); + case HloOpcode::kCos: + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value}, + {operand_value->getType()}, + ir_builder_); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, @@ -381,6 +386,118 @@ StatusOr ElementalIrEmitter::EmitErfcInv( return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } +StatusOr ElementalIrEmitter::EmitReducePrecision( + const HloInstruction* hlo, llvm::Value* x) const { + if (hlo->operand(0)->shape().element_type() != F32) { + return Unimplemented("reduce-precision only implemented for F32"); + } + + // Integer and float types for casting and constant generation. + llvm::Type* float_type = x->getType(); + llvm::IntegerType* int_type = ir_builder_->getInt32Ty(); + + // Cast the input value to an integer for bitwise manipulation. + llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type); + + if (hlo->mantissa_bits() < 23) { + // Last remaining mantissa bit. + const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits()); + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. 
+ const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; + llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr( + ir_builder_->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + (23 - hlo->mantissa_bits())); + llvm::Value* x_rounding_bias = ir_builder_->CreateAdd( + x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = ir_builder_->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + } + + if (hlo->exponent_bits() < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. 
+ const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = + (1 << (hlo->exponent_bits() - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + llvm::Value* x_exponent = ir_builder_->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + llvm::Value* x_overflows = ir_builder_->CreateICmpUGT( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); + llvm::Value* x_underflows = ir_builder_->CreateICmpULE( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); + + // Compute appropriately-signed values of zero and infinity. + llvm::Value* x_signed_zero = ir_builder_->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); + llvm::Value* x_signed_inf = ir_builder_->CreateOr( + x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int); + } + + // Cast the result back to a floating-point type. + llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type); + + // Correct result for NaN inputs. + // + // The exponent handling will "normalize" NaN values to infinities, which is + // undesirable (except in the case with no mantissa bits, in which case it + // is mandatory). This logic also handles cases where mantissa-rounding + // causes a NaN's mantissa to overflow into the exponent bits, which would + // otherwise create an erroneous zero value. 
+ // + // If the fast-math flags are set to assume no NaNs, the comparison is likely + // to be optimized away, so there's no point in even emitting it. + if (!ir_builder_->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x); + + if (hlo->mantissa_bits() > 0) { + result = ir_builder_->CreateSelect(x_is_nan, x, result); + } else { + result = ir_builder_->CreateSelect( + x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); + } + } + return result; +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { @@ -588,20 +705,37 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)}, {param_ir_type}, ir_builder_); auto in_block = ir_builder_->GetInsertBlock(); - auto body_block = in_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_body"); - SetToFirstInsertPoint(body_block, ir_builder_); - auto out_block = body_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_out"); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. 
+ CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(), + in_block->getTerminator() == nullptr); + + llvm::BasicBlock* body_block; + llvm::BasicBlock* out_block; + + if (ir_builder_->GetInsertPoint() == in_block->end()) { + body_block = + llvm_ir::CreateBasicBlock(nullptr, "rng_body", ir_builder_); + out_block = + llvm_ir::CreateBasicBlock(nullptr, "rng_out", ir_builder_); + llvm::BranchInst::Create(body_block, in_block); + } else { + body_block = in_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_body"); + out_block = body_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_out"); + body_block->getTerminator()->eraseFromParent(); + } + SetToFirstInsertPoint(body_block, ir_builder_); auto random = ir_builder_->CreateAnd( ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type), ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0), leading_zeros)); - llvm::ReplaceInstWithInst( - body_block->getTerminator(), - llvm::BranchInst::Create(out_block, body_block, - ir_builder_->CreateICmpULT(random, r))); + llvm::BranchInst::Create(out_block, body_block, + ir_builder_->CreateICmpULT(random, r), + body_block); SetToFirstInsertPoint(out_block, ir_builder_); return ir_builder_->CreateAdd( p, ir_builder_->CreateSelect( @@ -647,6 +781,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCeil: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: @@ -720,6 +855,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, 2))); return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); }; + case HloOpcode::kReducePrecision: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + 
return EmitReducePrecision(hlo, operand_value); + }; case HloOpcode::kConcatenate: return [this, hlo, &operand_to_generator]( const IrArray::Index target_index) -> StatusOr { @@ -805,23 +948,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - IrArray::Index sliced_index(index.size()); - for (int i = 0; i < index.size(); ++i) { - int64 stride = hlo->slice_stride(i); - if (stride != 1) { - sliced_index[i] = ir_builder_->CreateAdd( - ir_builder_->CreateMul( - index[i], llvm::ConstantInt::get(index[i]->getType(), - stride)), - llvm::ConstantInt::get(index[i]->getType(), - hlo->slice_starts(i))); - } else { - sliced_index[i] = ir_builder_->CreateAdd( - index[i], - llvm::ConstantInt::get(index[i]->getType(), - hlo->slice_starts(i))); - } - } + IrArray::Index sliced_index = index.SourceIndexOfSlice( + /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(), + /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_); return operand_to_generator.at(hlo->operand(0))(sliced_index); }; case HloOpcode::kDynamicSlice: diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 2576d3823e06ed3050554b38766dbd6c6a48ca5c..bb9117ca61e3b6ccb7f1fcecb62b0be5f984e6d1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -84,6 +84,9 @@ class ElementalIrEmitter { virtual StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, + llvm::Value* x) const; + // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its // `operand_no`-th operand. 
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 3a9f8dc79ee0589f27fe5aabf9592a73f34c4a0e..20eb1aea375e67dc296e11554e63da3d293e1fdd 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -21,39 +21,9 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/regexp.h" namespace xla { -/* static */ void Executable::DumpExecutedHlo( - const HloModule& module, const string& label, - const HloExecutionProfile* profile) { - VLOG(2) << "module name = " << module.name(); - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - string generate_hlo_graph_regex; - if (!flags->xla_generate_hlo_graph.empty()) { - generate_hlo_graph_regex = flags->xla_generate_hlo_graph; - } else { - generate_hlo_graph_regex = - module.config().debug_options().xla_generate_hlo_graph(); - } - if (!generate_hlo_graph_regex.empty() && - RE2::PartialMatch(module.name(), generate_hlo_graph_regex)) { - hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, - flags->xla_hlo_graph_addresses, - flags->xla_hlo_graph_layout, profile); - } - if (!flags->xla_log_hlo_text.empty() && - RE2::PartialMatch(module.name(), flags->xla_log_hlo_text)) { - LOG(INFO) << "HLO for module " << module.name(); - LOG(INFO) << "Label: " << label; - XLA_LOG_LINES(2, module.ToString()); - } - if (!flags->xla_dump_hlo_text_to.empty()) { - hlo_graph_dumper::DumpText(module, label, flags->xla_dump_hlo_text_to); - } -} - StatusOr> Executable::ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 291916cd9f7acb0c136dc0834b28f57a83736ec6..b36a44e19ea5b33ca6b5fc85a775e7cb6ff661f7 100644 --- 
a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/session.pb.h" @@ -49,10 +50,6 @@ class Executable { shape_size_function_(std::move(shape_size_function)) {} virtual ~Executable() {} - // Dumps the executed HLO according to service-associated flags. - static void DumpExecutedHlo(const HloModule& module, const string& label, - const HloExecutionProfile* profile); - // Enqueues the compilation result on the provided stream, passing the given // arguments. This call is blocking and returns after the execution is done. 
// @@ -240,7 +237,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( } } } - DumpExecutedHlo(module(), "Service::Execute", profile_ptr); + hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", + profile_ptr); } return return_value; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index bb4712c86f6d649a9ec8f1450d90735de9ec43c3..a08506d84d1960e8d9d4bbb2ef124bab644725a7 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, param0, false_constant)); @@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { { HloComputation::Builder builder(TestName() + ".entry"); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateWhile( ShapeUtil::MakeShape(PRED, {}), 
cond_computation, cond_computation, false_constant)); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index eb8b93330fbc7b786c66a07f8009b4676358421b..476b2b8d6f8069efcad1a415f1a766a0af3a1ec3 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -82,13 +82,12 @@ Status GenericTransferManager::TransferLiteralFromDevice( } *literal->mutable_shape() = device_shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal); + literal->Reserve(ShapeUtil::ElementsIn(device_shape)); TF_RETURN_IF_ERROR(TransferBufferFromDevice( executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape), - /*destination=*/LiteralUtil::MutableInternalData(literal))); + /*destination=*/literal->MutableInternalData())); if (!ShapeUtil::Equal(literal_shape, device_shape)) { - literal->Swap( - LiteralUtil::Relayout(*literal, literal_shape.layout()).get()); + literal->Swap(literal->Relayout(literal_shape.layout()).get()); } TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); return Status::OK(); @@ -152,14 +151,20 @@ Status GenericTransferManager::TransferLiteralToDevice( tuple_elements_on_device.data(), destination); } - return TransferBufferToDevice( - executor, /*size=*/GetByteSizeRequirement(shape), - /*source=*/LiteralUtil::InternalData(literal), destination); + return TransferBufferToDevice(executor, + /*size=*/GetByteSizeRequirement(shape), + /*source=*/literal.InternalData(), destination); } Status GenericTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const Literal& literal) { - return Unimplemented("Infeed is not supported on GPU (b/30467474)"); + return Unimplemented("Generic transfer to Infeed"); +} + +Status GenericTransferManager::TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) { + return 
Unimplemented("Generic transfer to Infeed"); } Status GenericTransferManager::TransferLiteralFromOutfeed( diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 2fbdb94f06f1b12763571dc2aa9b0d770f420406..48c061d28e5967f903e9ea665fdaeb02fab7e02e 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -54,6 +54,8 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; Status TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 869869341179822aa8d9e9675211be92f733077d..68090996953d99f7e2f72b0ade35d46a35884669 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -70,6 +70,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/legacy_flags:stream_assignment_flags", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", ], ) @@ -253,7 +254,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:convolution_thunk_flags", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", @@ -267,7 +267,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", 
"//tensorflow/core/platform/default/build_config:cudnn_plugin", - "//tensorflow/core/platform/default/build_config:stream_executor_cuda", + "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep ], ) @@ -376,7 +376,6 @@ cc_test( ":fusion_merger", ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -418,7 +417,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:gpu_compiler_flags", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -500,8 +498,10 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_ordering", + "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/compiler/xla/service:hlo_scheduling", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 9a0b14eb7332358d0e68e6a40b47c94b88666eb6..20e0d8eb785daa07b3fcc5339efe950aac0dacad 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -287,10 +286,7 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( const ConvolutionDescriptor& convolution_descriptor, const BufferAllocations& buffer_allocations, se::Stream* stream) { // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - legacy_flags::ConvolutionThunkFlags* flags = - legacy_flags::GetConvolutionThunkFlags(); - if (flags->xla_gpu_autotune_convolution_algorithm && - best_algorithm_.algorithm() == se::dnn::kDefaultAlgorithm) { + if (best_algorithm_.algorithm() == se::dnn::kDefaultAlgorithm) { // Auto-tuning either is disabled or only happens in the first run of this // function. VLOG(2) << "Profiling for best convolution algorithm used for " diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index aaf72935e61ee8b8da7df410ba3aaed63800cfd9..91d6df299da2686d6d836445d391c4b0eaf4ed00 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -81,9 +81,8 @@ class ConvolutionThunk : public Thunk { ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; - // Does the convolution for the thunk on "stream". If the - // xla_gpu_autotune_convolution_algorithm is turned on, auto-tuning happens on - // the first run of this function. + // Does the convolution for the thunk on "stream". Auto-tuning happens on the + // first run of this function. 
tensorflow::Status ExecuteOnStream( const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 2987c8913d7cdd93d57bfcca40d6c56ae4dd30f0..c2dec7ed6af575e390fa075a279861791e2d29a7 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -55,7 +55,7 @@ using tensorflow::strings::StrAppend; // Returns whether operand is a floating-point literal with the given value. bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { return operand->opcode() == HloOpcode::kConstant && - LiteralUtil::IsAllFloat(operand->literal(), value); + operand->literal().IsAllFloat(value); } GpuElementalIrEmitter::GpuElementalIrEmitter( diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index afb78b8300b457ba9384bd66f789d333630b51e4..e698646d18021733b78de53443316100021a995c 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -98,7 +98,13 @@ double CalculateFlopsToBytesRatio(HloInstruction* fusion) { // Calculate total bytes transferred in/out. double bytes = CalculateBytesReadByFusionInstruction(fusion); // Add bytes written to root instructions buffer. - bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); + if (fusion->IsMultiOutputFusion()) { + for (auto& operand : fusion->fused_expression_root()->operands()) { + bytes += ShapeUtil::ByteSizeOf(operand->shape()); + } + } else { + bytes += ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); + } // Calculate flops for all fused instructions. Use a null shape size function // because we don't care about bytes accessed by the ops. 
HloCostAnalysis analysis([](const Shape& shape) { return 0; }); @@ -112,8 +118,15 @@ double CalculateFlopsToBytesRatio(HloInstruction* fusion) { double GetCurrentBytesTransferred(HloInstruction* fusion) { CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); const double bytes_read = CalculateBytesReadByFusionInstruction(fusion); - const double bytes_written = - ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); + double bytes_written = 0; + if (fusion->IsMultiOutputFusion()) { + for (auto& operand : fusion->fused_expression_root()->operands()) { + bytes_written += ShapeUtil::ByteSizeOf(operand->shape()); + } + } else { + bytes_written = + ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape()); + } // Current bytes transferred (ignoring non 'fusion' user operands) is bytes // read and written by 'fusion', plus reads of size 'bytes_written' for each // user. @@ -198,6 +211,12 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { ++num_fail_not_loop_fusion_; return Status::OK(); } + + // Skip multiple output fusion. It's not yet supported. + if (fusion->IsMultiOutputFusion()) { + ++num_fail_not_loop_fusion_; + return Status::OK(); + } // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 8afc32dea97ea00442d2f094c8d6de0b510482fd..242c32936d31d0cb578825cade5f35979077a44e 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -59,7 +59,7 @@ class FusionMergerTest : public HloTestBase { // Create const vector of ones to be used in element-wise computations. 
auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create simple fusable computation for tuple element 0 (wont get merged). auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -138,7 +138,7 @@ class FusionMergerTest : public HloTestBase { // Create two sub-computations, both of which are users of 'mul0'. auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // First sub-computation: out0 = Mul(Add(mul0, one_vec), one_vec) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -209,7 +209,7 @@ class FusionMergerTest : public HloTestBase { // Create two fusable sub-computations which are dependent on shared // computation 'reduce_out'. auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // First sub-computation: out0 = Mul(Add(reduce_out, one_vec), one_vec) auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 86137a569f9b199782462582ba11683ff9884d7b..47fba3682251d0c3f20878b4608272661c74634a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "external/llvm/include/llvm/IR/DiagnosticPrinter.h" #include "external/llvm/include/llvm/IR/LLVMContext.h" #include "external/llvm/include/llvm/IR/Module.h" -#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -95,11 +94,9 @@ constexpr int64 kMemoryAlignment = 256; // called in GpuCompiler's constructor, so can't return an error. But // GpuCompiler::Compile will return an error when the wanted libdevice file // doesn't exist in the folder this function returns. -string GetLibdeviceDir() { +string GetLibdeviceDir(const HloModuleConfig& config) { std::vector potential_libdevice_dirs; - // Flag xla_cuda_data_dir specified by the user. - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); - const string datadir = flags->xla_cuda_data_dir; + const string datadir = config.debug_options().xla_gpu_cuda_data_dir(); if (!datadir.empty()) { potential_libdevice_dirs.push_back(datadir); } @@ -122,14 +119,13 @@ string GetLibdeviceDir() { // Runs optimization passes on the given HLO module. 
tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - const Compiler::HloDumper& dump_hlo, const se::DeviceDescription& device_desc) { { - HloPassPipeline pipeline("optimization", dump_hlo); + HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); { - auto& pass = pipeline.AddPass>( - "simplification", dump_hlo); + auto& pass = + pipeline.AddPass>("simplification"); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -149,7 +145,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { - HloPassFix fusion("fusion", dump_hlo); + HloPassFix fusion("fusion"); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); @@ -159,14 +155,13 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting( - const Compiler::HloDumper& dump_hlo, HloModule* hlo_module) { +tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. 
- HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo); + HloPassPipeline pipeline("GPU-ir-emit-prepare"); pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass( @@ -230,17 +225,15 @@ void DumpPtxasInfo(const string& ptx) { } // namespace GpuCompiler::GpuCompiler() - : libdevice_dir_(GetLibdeviceDir()), - pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} + : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} StatusOr> GpuCompiler::Compile( - std::unique_ptr module, HloDumper dump_hlo, - se::StreamExecutor* stream_exec) { + std::unique_ptr module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), dump_hlo, - stream_exec->GetDeviceDescription())); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, module.get())); + TF_RETURN_IF_ERROR( + OptimizeHloModule(module.get(), stream_exec->GetDeviceDescription())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); llvm::LLVMContext llvm_context; std::string buffer; @@ -271,13 +264,16 @@ StatusOr> GpuCompiler::Compile( TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), kMemoryAlignment)); + BufferSizeBytesFunction(), [](LogicalBuffer::Color) { + return kMemoryAlignment; + })); - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); - if (!flags->xla_gpu_dump_debug_json_to.empty()) { + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_gpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), @@ -292,7 +288,9 @@ StatusOr> 
GpuCompiler::Compile( entry_computation->root_instruction()->Accept(&ir_emitter)); string ir_module_string_before_opt; - if (VLOG_IS_ON(2) || flags->xla_gpu_embed_ir) { + const bool embed_ir_in_executable = + module->config().debug_options().xla_embed_ir_in_executable(); + if (VLOG_IS_ON(2) || embed_ir_in_executable) { ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); VLOG(2) << "LLVM module before optimizations:"; XLA_VLOG_LINES(2, ir_module_string_before_opt); @@ -313,6 +311,10 @@ StatusOr> GpuCompiler::Compile( cc_major = 2; cc_minor = 0; } + if (libdevice_dir_.empty()) { + // Compute libdevice_dir_ just once and cache it in this member. + libdevice_dir_ = GetLibdeviceDir(module->config()); + } TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, module->config(), libdevice_dir_)); @@ -333,7 +335,7 @@ StatusOr> GpuCompiler::Compile( auto* gpu_executable = new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(module), std::move(buffer_assignment), ShapeSizeBytesFunction()); - if (flags->xla_gpu_embed_ir) { + if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); } @@ -341,16 +343,15 @@ StatusOr> GpuCompiler::Compile( } StatusOr>> GpuCompiler::Compile( - std::vector> modules, HloDumper dump_hlos, + std::vector> modules, std::vector stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on GPU."); } StatusOr>> -GpuCompiler::CompileAheadOfTime( - std::vector> module, - HloDumper dump_hlo, const AotCompilationOptions& options) { +GpuCompiler::CompileAheadOfTime(std::vector> module, + const AotCompilationOptions& options) { return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 
da52f5ab1f8e5bf8c2fa3c33948ccf8a0f647f7a..b87555b931f1d73de8bcaf84aea80305c9d585bf 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -41,17 +41,16 @@ class GpuCompiler : public Compiler { ~GpuCompiler() override {} StatusOr> Compile( - std::unique_ptr module, HloDumper dump_hlo, + std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( - std::vector> modules, HloDumper dump_hlo, + std::vector> modules, std::vector stream_exec) override; StatusOr>> - CompileAheadOfTime( - std::vector> module, - HloDumper dump_hlo, AotCompilationOptions const& options) override; + CompileAheadOfTime(std::vector> module, + AotCompilationOptions const& options) override; perftools::gputools::Platform::Id PlatformId() const override; @@ -65,7 +64,7 @@ class GpuCompiler : public Compiler { private: // The parent directory of libdevice IR libraries. - const string libdevice_dir_; + string libdevice_dir_; // The list of PTX strings generated by this GpuCompiler. We let GpuCompiler // to own them because they need to be alive across the life span of the diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index d16a1d4ee5be00e685fc181f19c1a3cfda253f6a..c61e47a93ce3f71904a889a37184ebaa06417c62 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -67,38 +69,38 @@ GpuHloOrdering::GpuHloOrdering( // waits for its operands before executing. 
// // The predecessor map is built incrementally, in thunk launch order. We - // record the instructions already visited per stream in - // 'instructions_per_stream'. This lets us quickly determine the same-stream - // predecessors of each instruction. To capture cross-stream dependency edges, - // we use the predecessor map to insert each operand as well as its transitive - // closure of dependencies. - - // Compute the set of all instructions we will want to set reachability on - auto predecessor_map = MakeUnique( + // record the most-recently seen instructions per stream in + // 'last_instruction_per_stream'. This lets us quickly determine the + // same-stream predecessors of each instruction. + + // Compute the set of all instructions we will want to set reachability on. + auto predecessor_map = MakeUnique( module->entry_computation()->MakeInstructionPostOrder()); - std::vector> instructions_per_stream( - stream_assignment.StreamCount()); + // The most recently visited instruction per stream. + std::vector last_instruction_per_stream( + stream_assignment.StreamCount(), nullptr); for (const HloInstruction* hlo : thunk_launch_order) { + predecessor_map->SetReachable(hlo, hlo); if (stream_assignment.HasStreamAssigned(*hlo)) { + // Gather all instruction which are immediate predecessors of 'hlo' in the + // reachability graph. + std::vector immediate_preds; + immediate_preds.insert(immediate_preds.end(), hlo->operands().begin(), + hlo->operands().end()); + immediate_preds.insert(immediate_preds.end(), + hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + // All ops already queued on the same instruction stream, and their - // transitive predecessors, are predecessors. Since the relation is - // transitive, we just set the transitive closure of the previous op. + // transitive predecessors, are predecessors. 
const int stream_no = stream_assignment.StreamNumberForHlo(*hlo); - std::vector* instructions = - &instructions_per_stream[stream_no]; - if (!instructions->empty()) { - const HloInstruction* back = instructions->back(); - predecessor_map->SetReachableAndTransitiveClosure(hlo, back); - } - // All operands and their transitive predecessors are predecessors. Each - // operand must already exist in 'predecessor_map', since we're iterating - // in thunk launch order. - for (const HloInstruction* operand : hlo->operands()) { - predecessor_map->SetReachableAndTransitiveClosure(hlo, operand); + if (last_instruction_per_stream[stream_no] != nullptr) { + immediate_preds.push_back(last_instruction_per_stream[stream_no]); } - instructions->push_back(hlo); + predecessor_map->SetReachabilityToUnion(immediate_preds, hlo); + last_instruction_per_stream[stream_no] = hlo; } else { // Only parameters and constants don't have an assigned stream, since they // don't require a thunk. These ops don't have any predecessors. @@ -107,12 +109,11 @@ GpuHloOrdering::GpuHloOrdering( CHECK_EQ(hlo->operand_count(), 0); } } - strict_predecessors_.emplace(module->entry_computation(), - std::move(predecessor_map)); + predecessors_.emplace(module->entry_computation(), + std::move(predecessor_map)); - // The ordering of instructions in subcomputations is based solely on data - // dependencies. I.e. the strict predecessors of each subcomputation - // instruction is its transitive operands. + // The ordering of instructions in subcomputations is based solely on control + // and data dependencies. // // TODO(toddw): Each subcomputation is actually emitted as a function in DFS // postorder, so we can do better and establish the total order here. We don't @@ -120,8 +121,8 @@ GpuHloOrdering::GpuHloOrdering( // by IrEmitterNested. And mismatched ordering bugs would be hard to find. 
for (auto& computation : module->computations()) { if (computation.get() != module->entry_computation()) { - strict_predecessors_.emplace(computation.get(), - computation->ComputeTransitiveOperands()); + predecessors_.emplace(computation.get(), + computation->ComputeReachability()); } } } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h index 773973010a46bb4a2af1f536c43201ba8c0be5d8..1ce7a48ac8fcbbad0b3697845681582fe806b322 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h @@ -19,9 +19,9 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 1a61eec353740202065c1ce98e8c91274facfd19..a04214930dfc95b82ca4c702d12648381a4c8135 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -86,23 +86,35 @@ void HloToIrBindings::EmitBasePointersForHlos( continue; } - // A non-IO HLO with a buffer is bound to - // (1) an alloca if it is thread-local, or - // (2) an internal pointer in temp_buffer_base according to its offset. 
- const BufferAllocation::Slice slice = - buffer_assignment_->GetUniqueTopLevelSlice(non_io_hlo) - .ConsumeValueOrDie(); - if (slice.allocation()->is_thread_local()) { - llvm::Type* pointee_type = - llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_); - BindHloToIrValue(*non_io_hlo, ir_builder_->CreateAlloca(pointee_type)); - } else { - const int64 offset = slice.offset(); - CHECK_NE(nullptr, temp_buffer_base_); - BindHloToIrValue(*non_io_hlo, - ir_builder_->CreateInBoundsGEP( - temp_buffer_base_, ir_builder_->getInt64(offset))); - } + ShapeUtil::ForEachSubshape( + non_io_hlo->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + // A non-IO HLO with a buffer is bound to + // (1) an alloca if it is thread-local, or + // (2) an internal pointer in temp_buffer_base according to its + // offset. + auto slice_result = + buffer_assignment_->GetUniqueSlice(non_io_hlo, index); + if (!slice_result.ok()) { + return; + } + const BufferAllocation::Slice slice = + slice_result.ConsumeValueOrDie(); + if (slice.allocation()->is_thread_local()) { + llvm::Type* pointee_type = + llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_); + BindHloToIrValue(*non_io_hlo, + ir_builder_->CreateAlloca(pointee_type), index); + } else { + const int64 offset = slice.offset(); + CHECK_NE(nullptr, temp_buffer_base_); + BindHloToIrValue( + *non_io_hlo, + ir_builder_->CreateInBoundsGEP(temp_buffer_base_, + ir_builder_->getInt64(offset)), + index); + } + }); } } @@ -112,7 +124,7 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), base_ptr), ir_builder_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, @@ -120,8 +132,10 @@ llvm::Value* 
HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, + const ShapeIndex& shape_index, llvm::Value* ir_value) { - llvm::Type* pointee_type = llvm_ir::ShapeToIrType(hlo.shape(), ir_builder_); + llvm::Type* pointee_type = llvm_ir::ShapeToIrType( + ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_builder_); llvm::Type* dest_type = pointee_type->getPointerTo(); llvm::Value* typed_ir_value; @@ -139,13 +153,24 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, } void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, - llvm::Value* ir_value) { + llvm::Value* ir_value, + const ShapeIndex& shape_index) { VLOG(2) << "Binding " << hlo.ToString(); - InsertOrDie(&base_ptrs_, &hlo, GetTypedIrValue(hlo, ir_value)); + + const Shape& hlo_shape = hlo.shape(); + llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value); + + if (!BoundToIrValue(hlo)) { + // Set the root of ShapeTree first before assigning the element ir value. 
+ InsertOrDie(&base_ptrs_, &hlo, ShapeTree(hlo_shape, nullptr)); + } + *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value; } -llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo) { - llvm_ir::IrArray ir_array(GetBasePointer(hlo), hlo.shape()); +llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, + const ShapeIndex& shape_index) { + llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), + ShapeUtil::GetSubshape(hlo.shape(), shape_index)); alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); return ir_array; } @@ -154,7 +179,7 @@ void HloToIrBindings::UnbindAllLocalIrValues() { std::vector hlos_to_unbind; for (auto& key_value : base_ptrs_) { if (!llvm::isa( - key_value.second->stripPointerCasts())) { + (key_value.second.element({}))->stripPointerCasts())) { hlos_to_unbind.push_back(key_value.first); } } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 5be2150801fbd2a3a624d9c87513d5cee7288bbd..2c59886e9ae410b6a6a1dd9973c75c061c8db808 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -48,7 +48,8 @@ class HloToIrBindings { tensorflow::gtl::ArraySlice non_io_hlos); // Rebinds the given HLO to the LLVM IR value that represent its address. - void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value); + void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, + const ShapeIndex& shape_index = {}); // Unbinds all IR values that's defined in an LLVM function, e.g., function // arguments and stack variables. Global variables will be kept in bindings_. @@ -64,15 +65,18 @@ class HloToIrBindings { llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } - // A helper method that returns the base pointer of the IrArray for "inst". 
- llvm::Value* GetBasePointer(const HloInstruction& hlo) const { + // A helper method that returns the base pointer of the IrArray containing the + // output of "inst".at the given ShapeIndex. + llvm::Value* GetBasePointer(const HloInstruction& hlo, + const ShapeIndex& shape_index = {}) const { auto it = base_ptrs_.find(&hlo); CHECK(it != base_ptrs_.end()); - return it->second; + return it->second.element(shape_index); } // Return the underlying IrArray of the output of the given instruction. - llvm_ir::IrArray GetIrArray(const HloInstruction& hlo); + llvm_ir::IrArray GetIrArray(const HloInstruction& hlo, + const ShapeIndex& shape_index = {}); private: // Emits IR to resolve (possibly) recursive GetTupleElement instructions. @@ -81,6 +85,7 @@ class HloToIrBindings { // Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape. llvm::Value* GetTypedIrValue(const HloInstruction& hlo, + const ShapeIndex& shape_index, llvm::Value* ir_value); const BufferAssignment* buffer_assignment_; @@ -90,7 +95,10 @@ class HloToIrBindings { llvm::IRBuilder<>* ir_builder_; // Stores the underlying llvm::IrArray for each HloInstruction. - std::unordered_map base_ptrs_; + // For an instruction that generates multiple outputs, the root will be a + // tuple shape. The IrArray for each element output is stored in the subnode + // in the ShapeTree. + std::unordered_map> base_ptrs_; // The address of the memory block that contains all temporary buffers. llvm::Value* temp_buffer_base_; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index 120a3f7fba2101ce64da1e8135fb5f862e603fe4..ee5b447c9cd0b1fde4d3a0943d5d4cb8cc5b3376 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/core/platform/logging.h" namespace se = ::perftools::gputools; @@ -22,23 +24,23 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { -InfeedManager::InfeedManager() - : current_buffer_(nullptr), - host_to_device_executor_(nullptr) {} +InfeedManager::InfeedManager() : host_to_device_executor_(nullptr) {} void InfeedManager::Reset() { tensorflow::mutex_lock l(mu_); - CHECK(!current_buffer_); + CHECK(dequeued_buffer_.empty()); for (auto buffer : enqueued_buffer_) { buffer->Done(); } enqueued_buffer_.clear(); } -void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) { +void InfeedManager::EnqueueBuffers(const std::vector& buffers) { tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffer_.empty(); - enqueued_buffer_.push_back(buffer); + for (gpu::InfeedBuffer* b : buffers) { + enqueued_buffer_.push_back(b); + } if (was_empty) { // This has the potential to suffer from the notified thread // immediately trying and failing to acquire mu_, but seems @@ -53,18 +55,23 @@ InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { while (enqueued_buffer_.empty()) { cv_.wait(l); } - CHECK(!current_buffer_); - current_buffer_ = enqueued_buffer_.front(); + InfeedBuffer* current_buffer = enqueued_buffer_.front(); enqueued_buffer_.pop_front(); - return current_buffer_; + dequeued_buffer_.insert(current_buffer); + return current_buffer; } -void InfeedManager::ReleaseCurrentBuffer(se::DeviceMemoryBase* device_memory) { - tensorflow::mutex_lock l(mu_); - CHECK(current_buffer_); - CHECK(device_memory->IsSameAs(*current_buffer_->device_memory())); - current_buffer_->Done(); - current_buffer_ = nullptr; +void InfeedManager::ReleaseBuffers(const std::vector& 
buffers) { + { + tensorflow::mutex_lock l(mu_); + for (gpu::InfeedBuffer* b : buffers) { + CHECK(ContainsKey(dequeued_buffer_, b)); + dequeued_buffer_.erase(b); + } + } + for (gpu::InfeedBuffer* b : buffers) { + b->Done(); + } } se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index 50d0ce340f3d85c2c46f111dba3e316ff0f4df1a..73d5a5ce35497f156a181371bfb97fc37a8eb09e 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -81,25 +82,19 @@ class InfeedManager { // condition is to call Reset when no computation is taking place. void Reset(); - // Adds buffer to the infeed queue. buffer->Done will be called when - // the buffer will no longer be accessed by the InfeedManager, - // either as a result of a call to Reset or because the runtime has - // dequeued and used the buffer. - void EnqueueBuffer(InfeedBuffer* buffer); + // Adds a set of buffers to the infeed queue atomically. buffer->Done + // will be called when the buffer will no longer be accessed by the + // InfeedManager, either as a result of a call to Reset or because the + // runtime has dequeued and used the buffer. + void EnqueueBuffers(const std::vector& buffers); // Blocks until the infeed queue is non-empty, then returns the - // buffer at the head of the queue. Sets the current buffer to be - // the returned buffer. It is an error to call BlockingDequeueBuffer - // if there is an unreleased current buffer, i.e., - // ReleaseCurrentBuffer must be called between calls to - // BlockingDequeueBuffer. + // buffer at the head of the queue. 
Adds the current buffer to the + // to-be released set. InfeedBuffer* BlockingDequeueBuffer(); - // Releases the current buffer, which is the last buffer returned by - // BlockingDequeueBuffer and not yet released. device_memory must - // match that of the current buffer. - void ReleaseCurrentBuffer( - perftools::gputools::DeviceMemoryBase* device_memory); + // Releases a set of buffers from the to-be released set. + void ReleaseBuffers(const std::vector& buffers); // Returns a cached stream associated with an executor. Allocates a // new stream on the first invocation. On subsequent invocations, if @@ -109,18 +104,25 @@ class InfeedManager { perftools::gputools::StreamExecutor* executor); private: + // TODO(b/30467474): Revisit if this mutex becomes a point of + // contention. tensorflow::mutex mu_; + // Condition variable that is signaled every time a buffer is // enqueued to an empty queue. tensorflow::condition_variable cv_; + // InfeedBuffer* queue contents are not owned, but buffer->Done must // be called when the buffer is no longer needed by the runtime. std::deque enqueued_buffer_; - // If non-NULL, the buffer that is currently being processed by the + + // Buffers that are dequeued and currently being processed by the // runtime. Not owned. - InfeedBuffer* current_buffer_; + tensorflow::gtl::FlatSet dequeued_buffer_; + // Cached host to device stream for queuing infeed data. std::unique_ptr host_to_device_stream_; + // Executor that the host_to_device_stream belongs to. Not owned. perftools::gputools::StreamExecutor* host_to_device_executor_; }; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 6f144c7273e69beedeb143c395ce37414ce99139..e33e904692ca5ad41e17d2e165dbb40b6bd4aa33 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -21,31 +21,59 @@ limitations under the License. 
namespace xla { namespace gpu { -InfeedThunk::InfeedThunk(const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction) +InfeedThunk::InfeedThunk( + tensorflow::gtl::ArraySlice tuple_element_buffers, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* hlo_instruction) : Thunk(Kind::kInfeed, hlo_instruction), - destination_buffer_(destination_buffer), - mem_size_(mem_size) {} + tuple_element_buffers_(tuple_element_buffers.begin(), + tuple_element_buffers.end()), + destination_buffer_(destination_buffer) {} tensorflow::Status InfeedThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) { VLOG(2) << "Infeeding to GPU "; - perftools::gputools::DeviceMemoryBase destination_data = + + perftools::gputools::DeviceMemoryBase destination_address = buffer_allocations.GetDeviceAddress(destination_buffer_); InfeedManager* infeed_manager = GetOrCreateInfeedManager(); - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - CHECK_EQ(buffer->length(), mem_size_); - stream->ThenMemcpy(&destination_data, *(buffer->device_memory()), - buffer->length()); + std::vector infeed_buffers; + if (ShapeUtil::IsTuple(hlo_instruction()->shape())) { + CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape())); + // Transfer the tuple elements first. + std::vector tuple_element_addresses; + for (BufferAllocation::Slice tuple_element_buffer : + tuple_element_buffers_) { + perftools::gputools::DeviceMemoryBase tuple_element_address = + buffer_allocations.GetDeviceAddress(tuple_element_buffer); + + InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); + infeed_buffers.push_back(buffer); + stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()), + buffer->length()); + tuple_element_addresses.push_back(tuple_element_address.opaque()); + } + // Transfer the tuple outer buffer. 
+ auto host_size = tuple_element_addresses.size() * sizeof(void*); + stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(), + host_size); + } else { + InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); + infeed_buffers.push_back(buffer); + stream->ThenMemcpy(&destination_address, *(buffer->device_memory()), + buffer->length()); + } + if (!stream->BlockHostUntilDone()) { return InternalError("Failed to complete data transfer on stream %p", stream); } - // Since Infeeds are totally ordered, no other infeed should sneak - // in and we should be able to release the same buffer we dequeued. - infeed_manager->ReleaseCurrentBuffer(buffer->device_memory()); + + infeed_manager->ReleaseBuffers(infeed_buffers); + + VLOG(2) << "Infeeding to GPU complete"; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 0a808186c212660e4be3905456d29cb2fed0f511..371d71f9dbdd21cb5f36cc3108c8f398a4a91c29 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -35,8 +35,10 @@ class InfeedThunk : public Thunk { // infeed queue to the device buffer // `destination_buffer`. `mem_size` is the size of the data in // bytes. 
- InfeedThunk(const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction); + InfeedThunk(tensorflow::gtl::ArraySlice + tuple_element_buffers, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* hlo_instruction); InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; @@ -46,8 +48,8 @@ class InfeedThunk : public Thunk { perftools::gputools::Stream* stream) override; private: + const std::vector tuple_element_buffers_; const BufferAllocation::Slice destination_buffer_; - const uint64 mem_size_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 607a366ac67d98d11c5141b390420aef00539dcd..718e27101e0dc2bfb1338f17979d452b08a2a376 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -118,8 +118,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { IrEmitterContext* ir_emitter_context, bool is_nested); // A convenient helper for calling HloToIrBindings::GetIrArray. - llvm_ir::IrArray GetIrArray(const HloInstruction& inst) { - return bindings_.GetIrArray(inst); + llvm_ir::IrArray GetIrArray(const HloInstruction& inst, + const ShapeIndex& shape_index = {}) { + return bindings_.GetIrArray(inst, shape_index); } // A convenient helper for calling HloToIrBindings::GetBasePointer. llvm::Value* GetBasePointer(const HloInstruction& inst) const { @@ -231,7 +232,7 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. 
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; Status HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 5fa2bfdd7e4301144054e0d4f41d1161e798176b..484de369675fb0188754d4bc2d187cbc6c92259b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -722,8 +722,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace -Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { if (ImplementedAsMemcpy(*copy)) { thunk_sequence_->emplace_back(BuildCopyThunk(copy)); return Status::OK(); @@ -731,7 +730,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, bool is_transpose_021; Shape reduced_input_shape, reduced_output_shape; std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) = - IsTranspose021(operand->shape(), copy->shape()); + IsTranspose021(copy->operand(0)->shape(), copy->shape()); if (is_transpose_021 && reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled && reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) { @@ -739,7 +738,8 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, VLOG(3) << "Emitting tiled 0-2-1 transposition"; constexpr int64 tile_size = 32; int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*operand).CastToShape(reduced_input_shape, &ir_builder_), + GetIrArray(*(copy->operand(0))) + .CastToShape(reduced_input_shape, &ir_builder_), GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), tile_size, 
&ir_builder_); UpdateLaunchDimensions(LaunchDimensions(num_tiles, tile_size), LastThunk(), @@ -747,7 +747,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, return Status::OK(); } - return IrEmitter::HandleCopy(copy, operand); + return IrEmitter::HandleCopy(copy); } Status IrEmitterUnnested::EmitColumnReduction( @@ -1648,7 +1648,7 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); return MakeUnique( - /*source_address=*/LiteralUtil::InternalData(operand->literal()), + /*source_address=*/operand->literal().InternalData(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ llvm_ir::ByteSizeOf(operand->shape(), @@ -1659,12 +1659,18 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); + + std::vector tuple_element_buffers; + for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) { + BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(inst, {i}) + .ConsumeValueOrDie(); + tuple_element_buffers.push_back(buffer); + } + return MakeUnique( - /*destination_buffer=*/GetAllocationSlice(*inst), - /*mem_size=*/ - llvm_ir::ByteSizeOf(inst->shape(), - ir_emitter_context_->llvm_module()->getDataLayout()), - inst); + tuple_element_buffers, + /*destination_buffer=*/GetAllocationSlice(*inst), inst); } std::unique_ptr IrEmitterUnnested::BuildGemmThunk( @@ -1880,15 +1886,38 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) { + const Shape& element_shape = hlo.IsMultiOutputFusion() + ? 
ShapeUtil::GetSubshape(hlo.shape(), {0}) + : hlo.shape(); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - hlo.shape(), ir_emitter_context_->device_description()); + element_shape, ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); - // Otherwise, emit a parallel loop that computes the partition that each - // thread is in charge of. - return ParallelLoopEmitter(element_generator, GetIrArray(hlo), - launch_dimensions, &ir_builder_) - .EmitLoop(); + if (!hlo.IsMultiOutputFusion()) { + return ParallelLoopEmitter(element_generator, GetIrArray(hlo), + launch_dimensions, &ir_builder_) + .EmitLoop(); + } + + // For multiple outputs fusion, we need to emit each operand and the root. + std::vector output_arrays; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { + output_arrays.push_back(GetIrArray(hlo, {i})); + } + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, + launch_dimensions, &ir_builder_) + .EmitLoop()); + + std::vector tuple_operand_ptrs; + for (int64 i = 0; i < output_arrays.size(); ++i) { + tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); + } + ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); + // const HloInstruction* root = hlo.fused_expression_root(); + llvm_ir::EmitTuple( + GetIrArray(*hlo.fused_expression_root()->fusion_instruction()), + tuple_operand_ptrs, &ir_builder_); + return Status::OK(); } Status IrEmitterUnnested::EmitTargetElementLoop( diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 724549c0c4ef46e7526953f41439ea8eff71a779..1d1e5bee542c1c682fa74121934348e7e7a1b026 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -28,10 +28,10 @@ cc_library( "utils.h", ], deps = [ + 
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:gpu_backend_lib_flags", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index e03571a9672df62593318766fcecf414e0899ea1..881522a0298a8c8cd45d03a4863ad5e995bd4b13 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" @@ -134,13 +133,8 @@ static string GetSmName(std::pair compute_capability) { // from the input filename. string MakeNameForTempProduct(const std::string& input_filename, tensorflow::StringPiece extension) { - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - return tensorflow::io::JoinPath( - flags->dump_temp_products_to, - ReplaceFilenameExtension( - tensorflow::io::Basename(llvm_ir::AsString(input_filename)), - extension)); + return ReplaceFilenameExtension( + tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension); } // Initializes LLVM passes. Uses the PassRegistry mechanism. @@ -177,20 +171,16 @@ std::unique_ptr GetTargetMachine( .xla_enable_fast_math(), &target_options); - // Enable FMA synthesis if desired. 
- legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (flags->fma) { - target_options.AllowFPOpFusion = FPOpFusion::Fast; - } + // Enable FMA synthesis. + target_options.AllowFPOpFusion = FPOpFusion::Fast; // Set the verbose assembly options. - target_options.MCOptions.AsmVerbose = flags->verbose_ptx_asm; + target_options.MCOptions.AsmVerbose = false; // The selection of codegen optimization level is copied from function // GetCodeGenOptLevel in //external/llvm/tools/opt/opt.cpp. CodeGenOpt::Level codegen_opt_level; - switch (flags->opt_level) { + switch (hlo_module_config.debug_options().xla_backend_optimization_level()) { case 1: codegen_opt_level = CodeGenOpt::Less; break; @@ -262,12 +252,10 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { // The extension is stripped by IrDumpingPassManager, so we need to // get creative to add a suffix. string module_id(llvm_ir::AsString(module->getModuleIdentifier())); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); IrDumpingPassManager codegen_passes( ReplaceFilenameExtension(tensorflow::io::Basename(module_id), "-nvptx.dummy"), - flags->dump_temp_products_to, flags->dump_ir_before_passes); + "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -345,36 +333,19 @@ StatusOr CompileModuleToPtx(llvm::Module* module, TF_RETURN_IF_ERROR( LinkLibdeviceIfNecessary(module, compute_capability, libdevice_dir_path)); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (!flags->dump_temp_products_to.empty()) { - string linked_filename = - MakeNameForTempProduct(module->getModuleIdentifier(), "linked.bc"); - LOG(INFO) << "dumping bitcode after linking libdevice to: " - << linked_filename; - EmitBitcodeToFile(*module, linked_filename); - } - // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass // can access it. 
- module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", flags->ftz); + module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", + hlo_module_config.debug_options().xla_gpu_ftz()); // If ftz is enabled, set it as an attribute on every function in the module. - if (flags->ftz) { + if (hlo_module_config.debug_options().xla_gpu_ftz()) { for (llvm::Function& fn : *module) { fn.addFnAttr("nvptx-f32ftz", "true"); } } - // Run IR-level optimizations. - if (flags->dump_ir_before_passes && flags->dump_temp_products_to.empty()) { - LOG(FATAL) << "--dump_ir_before_passes must be specified with " - "--dump_temp_products_to"; - } - - IrDumpingPassManager module_passes(module->getModuleIdentifier(), - flags->dump_temp_products_to, - flags->dump_ir_before_passes); + IrDumpingPassManager module_passes(module->getModuleIdentifier(), "", false); // Add an appropriate TargetLibraryInfo pass for the module's triple. llvm::TargetLibraryInfoWrapperPass* tliwp = @@ -406,8 +377,16 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // too. llvm::legacy::FunctionPassManager function_passes(module); - AddOptimizationPasses(flags->opt_level, /*size_level=*/0, - target_machine.get(), &module_passes, &function_passes); + int32 opt_level = + hlo_module_config.debug_options().xla_backend_optimization_level(); + + CHECK_GE(opt_level, 2) + << "The XLA GPU backend doesn't support unoptimized code generation"; + + AddOptimizationPasses(opt_level, + /*size_level=*/0, target_machine.get(), &module_passes, + &function_passes); + // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA // again after the standard optimization passes [http://b/13329423]. // TODO(jingyue): SROA may further expose more optimization opportunities, such @@ -415,7 +394,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // the inlining cost of a function). For now, running SROA already emits good // enough code for the evaluated benchmarks. 
We may want to run more // optimizations later. - if (flags->opt_level > 0) { + if (opt_level > 0) { // LLVM's optimizer turns on SROA when the optimization level is greater // than 0. We mimic this behavior here. module_passes.add(llvm::createSROAPass()); @@ -433,14 +412,6 @@ StatusOr CompileModuleToPtx(llvm::Module* module, function_passes.doFinalization(); module_passes.run(*module); - if (!flags->dump_temp_products_to.empty()) { - string optimized_filename = - MakeNameForTempProduct(module->getModuleIdentifier(), "optimized.bc"); - LOG(INFO) << "dumping bitcode after optimizations to: " - << optimized_filename; - EmitBitcodeToFile(*module, optimized_filename); - } - // Finally, produce PTX. return EmitModuleToPTX(module, target_machine.get()); } @@ -473,22 +444,6 @@ void GPUBackendInit() { // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (!flags->llvm_cl_opts.empty()) { - std::vector opts = - tensorflow::str_util::Split(flags->llvm_cl_opts, ','); - FeedLLVMWithFlags(opts); - } - - if (flags->llvm_dump_passes) { - // Enable LLVM pass debugging dump. LLVM dumps this information when a pass - // manager is initialized for execution. It's done to stderr (this is - // hardcoded within LLVM to the dbgs() stream, we can't change it from the - // outside). - FeedLLVMWithFlags({"-debug-pass=Arguments"}); - } - // Initialize the NVPTX target; it's the only target we link with, so call its // specific initialization functions instead of the catch-all InitializeAll*. 
LLVMInitializeNVPTXTarget(); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index a12a9a716829fbcf5b6348037fa723d5ddcc6930..b8c61620845a1434cc79dc9a8b00f89944e2ae95 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -61,7 +61,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + MakeUnique(Literal::Zero(element_type)))); input = computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape( /*operand_shape=*/input->shape(), @@ -127,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + MakeUnique(Literal::Zero(element_type)))); return computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape( /*operand_shape=*/kernel->shape(), @@ -242,9 +242,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. 
HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(input->shape().element_type())))); HloInstruction* padded_input = computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape(input->shape(), padding->shape(), diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 65610b0995c512cc4a611ac650c581d0180d258d..d5543d296b3f0f6b19de90c42bea4f162057802a 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -36,6 +36,13 @@ ParallelLoopEmitter::ParallelLoopEmitter( : LoopEmitter(body_emitter, shape, ir_builder), launch_dimensions_(launch_dimensions) {} +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + tensorflow::gtl::ArraySlice target_arrays, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_arrays, ir_builder), + launch_dimensions_(launch_dimensions) {} + ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 73ca28cd842fe350ecd10885d983907e7288a350..d324a50698ea0d3e5e196347bd69c29b2ad27e3e 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -41,6 +41,12 @@ class ParallelLoopEmitter : public 
llvm_ir::LoopEmitter { const llvm_ir::IrArray& target_array, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + + ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + tensorflow::gtl::ArraySlice target_arrays, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 5065e7aedd08c591f33c152c6709823948db54f0..a304decc4917cee1ad3ae4b28730a6ab751ec5f6 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" namespace xla { namespace gpu { @@ -46,10 +47,9 @@ namespace { // Returns whether the two HLOs can run concurrently, i.e., neither is a // transitive consumer of the other. -bool CanRunConcurrently( - const HloInstruction& a, const HloInstruction& b, - const HloComputation::ReachabilityMap& transitive_operands) { - return !transitive_operands.IsConnected(&a, &b); +bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b, + const HloReachabilityMap& reachability) { + return !reachability.IsConnected(&a, &b); } // Returns which existing stream to assign to `hlo`, or -1 if a stream is not @@ -58,7 +58,7 @@ bool CanRunConcurrently( // are topologically before `hlo`. 
int ComputeStreamToAssign( const HloInstruction& hlo, const StreamAssignment& stream_assignment, - const HloComputation::ReachabilityMap& transitive_operands, + const HloReachabilityMap& reachability, const std::vector& seen_gemms) { if (hlo.opcode() == HloOpcode::kParameter || hlo.opcode() == HloOpcode::kConstant) { @@ -96,7 +96,7 @@ int ComputeStreamToAssign( for (const auto* seen_gemm : seen_gemms) { int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm); if (!forbidden_stream_numbers.count(stream_no) && - CanRunConcurrently(*seen_gemm, hlo, transitive_operands)) { + CanRunConcurrently(*seen_gemm, hlo, reachability)) { forbidden_stream_numbers.insert(stream_no); } } @@ -115,12 +115,12 @@ int ComputeStreamToAssign( std::unique_ptr AssignStreams(const HloModule& module) { auto stream_assignment = MakeUnique(); const HloComputation& computation = *module.entry_computation(); - std::unique_ptr transitive_operands = - computation.ComputeTransitiveOperands(); + std::unique_ptr reachability = + computation.ComputeReachability(); std::vector seen_gemms; for (const auto* hlo : computation.MakeInstructionPostOrder()) { int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment, - *transitive_operands, seen_gemms); + *reachability, seen_gemms); if (stream_no != -1) { stream_assignment->AssignStreamToHlo(hlo, stream_no); } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index 06b01d311dac5a6be78d7b8b16e7fcb39c189647..3034ed06b7eaff46a923b19cedb39f02d276c9f8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -37,8 +37,8 @@ namespace { // patterns to match. // // Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and HloOpcode -// of the operand. +// of type ExprTree). 
Operands can be added by specifying the index and +// HloOpcode of the operand. // // For example, the following computation: // @@ -197,10 +197,9 @@ class MatcherBase { return InvalidArgument("Must use S32 or S64 integral types."); } if (type == S32) { - *const_value = - static_cast(LiteralUtil::GetFirstElement(literal)); + *const_value = static_cast(literal.GetFirstElement()); } else if (type == S64) { - *const_value = LiteralUtil::GetFirstElement(literal); + *const_value = literal.GetFirstElement(); } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index e82491fd6f9f1158fc5b9e5bd475ef6ff97f2a7c..51d38f84212b01c08c33f1b648c579c5672769ba 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -41,7 +41,7 @@ class WhileTransformerTest : public HloTestBase { const int64 tuple_index, const int64 limit) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(limit))); + HloInstruction::CreateConstant(Literal::CreateR0(limit))); auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto induction_variable = @@ -64,8 +64,8 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, ind_var_tuple_index)); - auto inc = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(increment))); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(increment))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data 
GTE(data_tuple_index). @@ -88,12 +88,10 @@ class WhileTransformerTest : public HloTestBase { const int64 ind_var_tuple_index, const int64 ind_var_init) { auto builder = HloComputation::Builder(TestName() + ".While"); - auto induction_var_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(ind_var_init))); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto induction_var_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(ind_var_init))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); auto loop_state_init = ind_var_tuple_index == 0 ? builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu_transfer_manager.cc index 4b8d190a463ceb155f4fc8d3d22b47b9cbc8f23f..74f0bdb7db1847119c5bd75cc9fd9d921c6e162a 100644 --- a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu_transfer_manager.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -28,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -44,24 +44,85 @@ GpuTransferManager::GpuTransferManager() Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { const Shape& shape = literal.shape(); - VLOG(2) << "Transferring literal shape to infeed: " + VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - // TODO(b/30467474) handle tuples. - if (ShapeUtil::IsTuple(shape)) { - return Unimplemented("Infeed with a tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + if (!ShapeUtil::IsTuple(shape)) { + int64 size = GetByteSizeRequirement(shape); + return TransferBufferToInfeed(executor, size, literal.InternalData()); } - int64 size = GetByteSizeRequirement(shape); + if (ShapeUtil::IsNestedTuple(shape)) { + return Unimplemented( + "Infeed with a nested tuple shape is not supported: %s", + ShapeUtil::HumanString(literal.shape()).c_str()); + } + + // For a tuple, we transfer each of its elements to the device and + // enqueue the resulting destination device addresses with the + // infeed manager. 
+ std::vector buffers; + buffers.reserve(literal.tuple_literals_size()); + auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { + for (gpu::InfeedBuffer* b : buffers) { + b->Done(); + } + }); + + for (const auto& tuple_element : literal.tuple_literals()) { + const Shape& tuple_element_shape = tuple_element.shape(); + int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); + TF_ASSIGN_OR_RETURN( + gpu::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, tuple_element_size, + tuple_element.InternalData())); + buffers.push_back(buffer); + } + + cleanup.release(); + return EnqueueBuffersToInfeed(executor, buffers); +} + +Status GpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, + const void* source) { + TF_ASSIGN_OR_RETURN(gpu::InfeedBuffer * buffer, + TransferBufferToInfeedInternal(executor, size, source)); + return EnqueueBuffersToInfeed(executor, {buffer}); +} + +Status GpuTransferManager::EnqueueBuffersToInfeed( + se::StreamExecutor* executor, std::vector buffers) { + gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); + se::Stream* stream = infeed_manager->GetStream(executor); + + // TODO(b/30467474): Since this stream is shared across different + // infeed requests, blocking on the stream might be + // heavy-handed. Figure out if finer-grained acknowledgement is + // possible. 
+ if (!stream->BlockHostUntilDone()) { + for (gpu::InfeedBuffer* b : buffers) { + b->Done(); + } + return InternalError("Failed to complete data transfer on stream %p", + stream); + } + + infeed_manager->EnqueueBuffers(buffers); + + VLOG(2) << "Infeed data transferred"; + + return Status::OK(); +} + +StatusOr GpuTransferManager::TransferBufferToInfeedInternal( + se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return Unimplemented("Infeed shape is too large: %s needs %lld bytes", - ShapeUtil::HumanString(literal.shape()).c_str(), size); + return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); } if (size == 0) { - return Unimplemented("Infeed shape %s needs 0 bytes", - ShapeUtil::HumanString(literal.shape()).c_str()); + return InvalidArgument("Infeed shape needs 0 bytes"); } gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); @@ -71,21 +132,11 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, } gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size); - stream->ThenMemcpy(buffer->device_memory(), - LiteralUtil::InternalData(literal), size); + stream->ThenMemcpy(buffer->device_memory(), source, size); VLOG(2) << "Queued infeed data on stream " << stream; - if (!stream->BlockHostUntilDone()) { - buffer->Done(); - return InternalError("Failed to complete data transfer on stream %p", - stream); - } - - infeed_manager->EnqueueBuffer(buffer); - - VLOG(2) << "Infeed data transferred"; - return Status::OK(); + return buffer; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu_transfer_manager.h index 6dfe7ba0295aea699ca737e9dd47123b17cae3dc..9aa369c668364079504ead3491903e2590a142cc 100644 --- a/tensorflow/compiler/xla/service/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu_transfer_manager.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -37,8 +38,21 @@ class GpuTransferManager : public GenericTransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; private: + // Initiates the infeed data transfers. InfeedBuffer->Done() must be + // called to clean up the memory allocated for InfeedBuffer. + StatusOr TransferBufferToInfeedInternal( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source); + + // Enqueues infeed data buffers with the infeed manager after their + // transfer completes. + Status EnqueueBuffersToInfeed(perftools::gputools::StreamExecutor* executor, + std::vector buffers); + TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager); }; diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index cd00a41a03718502fcfa63e035639390b6fe6e07..049e8d80d80c835bca4a4d38592564ba82a3ecf9 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -47,7 +47,7 @@ HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); + HloInstruction::CreateConstant(Literal::CreateR0(0.5))); builder.AddInstruction(HloInstruction::CreateBinary( half->shape(), HloOpcode::kAdd, x_value, half)); return 
module->AddEmbeddedComputation(builder.Build()); @@ -118,7 +118,7 @@ std::unique_ptr MakeBigGraph() { auto rng = builder.AddInstruction( HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_computation = ScalarSumComputation(module.get()); builder.AddInstruction( HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); @@ -156,10 +156,9 @@ int main(int argc, char** argv) { auto module = xla::MakeBigGraph(); - printf("Graph URL: %s\n", - xla::hlo_graph_dumper::DumpGraph( - *module->entry_computation(), "Example computation", - /*show_addresses=*/false, /*show_layouts=*/false) - .c_str()); + printf("Graph URL: %s\n", xla::hlo_graph_dumper::DumpGraph( + *module->entry_computation(), + "Example computation", xla::DebugOptions()) + .c_str()); return 0; } diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 86f62accd3b524c3aa39c256a982bcf21edc1b25..c662cec9c70c58cab4cd41b939bd1553e0c564bc 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -187,7 +187,7 @@ Status HeapSimulator::RunComputation( buffer->instruction()->opcode() != HloOpcode::kCopy && CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { + buffer->instruction(), buffer->index(), &points_to_analysis)) { ShareBuffer(buffer, operand_buffer, instruction); shared = true; break; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 60a0768a86b30ad5e8810a6f289008a9ee8c8a2e..ef9db8ba236f9923420c1f8b1a7423e0c036fb0f 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ 
b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -173,7 +173,7 @@ class HeapSimulatorTest : public HloTestBase { TEST_F(HeapSimulatorTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); // Constants aren't assigned. See b/32248867 HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0}); @@ -510,8 +510,7 @@ class HeapAlgorithmTestBase : public ::testing::Test { // other than the id and color. const LogicalBuffer* DummyLogicalBuffer() { const LogicalBuffer::Id id = buffers_.size(); - buffers_.emplace_back(MakeUnique(nullptr, ShapeIndex{}, id, - LogicalBuffer::Color(0))); + buffers_.emplace_back(MakeUnique(nullptr, ShapeIndex{}, id)); return buffers_.back().get(); } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 3b37f4a4b892497135c4dccc0082d244c1d8a27e..d03093f89f5247b3552c632fdf28c197c85f5fb8 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -15,21 +15,21 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" -#include -#include +#include +#include #include #include #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -38,105 +38,6 @@ using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -void HloBuffer::AddValue(const HloValue& value) { - // If the value is already contained in this buffer, just return. - if (std::find(value_ids_.begin(), value_ids_.end(), value.id()) != - value_ids_.end()) { - return; - } - - value_ids_.push_back(value.id()); - - // Add all of the locations of the HloValue to this buffer. - for (const HloLocation& location : value.locations()) { - if (std::find(locations_.begin(), locations_.end(), location) == - locations_.end()) { - locations_.push_back(location); - } - } -} - -bool HloBuffer::operator==(const HloBuffer& other) const { - bool equal = id() == other.id(); - if (equal) { - // DCHECK because these comparisons are expensive (linear time). 
- DCHECK(value_ids() == other.value_ids()); - DCHECK(locations() == other.locations()); - } - return equal; -} - -string HloBuffer::ToString() const { - return StrCat("HloBuffer ", id_, ", values: ", Join(value_ids_, ", ")); -} - -std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { - out << buffer.ToString(); - return out; -} - -void HloBufferSet::AddBuffer(HloBuffer::Id buffer_id) { - if (std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id) == - buffer_ids_.end()) { - buffer_ids_.push_back(buffer_id); - } -} - -void HloBufferSet::RemoveBufferOrDie(HloBuffer::Id buffer_id) { - auto it = std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id); - CHECK(it != buffer_ids_.end()); - buffer_ids_.erase(it); -} - -string HloBufferSet::ToString() const { - return StrCat("HloBufferSet, buffers: ", Join(buffer_ids_, ", ")); -} - -std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set) { - out << buffer_set.ToString(); - return out; -} - -bool InstructionBufferSet::IsAmbiguous() const { - bool is_ambiguous = false; - ForEachElement( - [&is_ambiguous](const ShapeIndex& index, const HloBufferSet& buffer_set) { - is_ambiguous |= buffer_set.buffer_ids().size() > 1; - }); - return is_ambiguous; -} - -bool InstructionBufferSet::IsDistinct() const { - bool is_distinct = true; - tensorflow::gtl::FlatSet seen_ids; - ForEachElement([&is_distinct, &seen_ids](const ShapeIndex& index, - const HloBufferSet& buffer_set) { - for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) { - auto pair = seen_ids.insert(buffer_id); - if (!pair.second) { - is_distinct = false; - } - } - }); - return is_distinct; -} - -string InstructionBufferSet::ToString() const { - string out = - StrCat("InstructionBufferSet(", ShapeUtil::HumanString(shape()), ")\n"); - ForEachElement([this, &out](const ShapeIndex& index, - const HloBufferSet& value_set) { - StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); - }); - return out; -} - 
-std::ostream& operator<<(std::ostream& out, - const InstructionBufferSet& buffer_set) { - out << buffer_set.ToString(); - return out; -} - HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {} void HloAliasAnalysis::InitializeBufferSets() { @@ -240,7 +141,7 @@ void HloAliasAnalysis::FlattenInstructionBufferSets( VLOG(4) << "Flattening buffer sets of instructions: " << Join(instructions, ", ", [this](string* out, const HloInstruction* instruction) { - StrAppend(out, instruction->FullyQualifiedName()); + StrAppend(out, instruction->name()); }); if (instructions.size() < 2) { return; @@ -282,7 +183,7 @@ string HloAliasAnalysis::ToString() const { module_->computations()) { for (const std::unique_ptr& instruction : computation->instructions()) { - StrAppend(&out, " ", instruction->FullyQualifiedName(), ":\n"); + StrAppend(&out, " ", instruction->name(), ":\n"); auto buffer_str = [this](const HloBuffer& buffer) { return StrCat( "Buffer ", buffer.id(), ", values: ", diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 0fa35827b5ecbfd3987a17e60c3b395b36b16b2e..429cfa09158fe1773001b16b58020c508a5d2b65 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -16,182 +16,23 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ -#include -#include #include -#include #include #include -#include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" namespace xla { -// A container which can hold one or more HloValues. An HLO buffer abstractly -// represents the allocation which HLO instructions write into and read -// from. Generally there is a one-to-one correspondence between HloBuffers and -// HloValue where each HloValue in the module is held in a unique HloBuffer. An -// exception is the while instruction which updates the loop state in-place. In -// this case, we have a single HloBuffer for each HloLocation in the loop state, -// but multiple HloValues. For example: -// -// %init = ... -// %while = While(%init, body, condition) -// -// body: -// %body_param = Param(0) -// ... -// %body_root = ... -// -// condition: -// %cond_param = Param(0) -// ... -// -// For simplicity, assume that %while is array-shaped. 
In this case, we have a -// single HloBuffer which holds the following HloValues: HloValue{%init}, -// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and -// HloValue{%cond_param}. -// -// HloBuffers may appear at different HloLocations in the module mirroring the -// same propery of HloValues. For example: -// -// %sub = Sub(...) -// %add = Add(...) -// %tuple = Tuple(%add, %sub) -// %gte = GetTupleElement(%tuple, 0) -// -// In this case, the HloBuffer containing %add appears at the following -// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and -// HloLocation{%gte, {}}. -// -// Different HloLocations which share the same HloBuffer indicate mandatory -// aliasing in the HLO module. These locations must share the same memory -// allocation for correctness (the backends rely on this property). This differs -// from incidental aliasing introduced by memory reuse in BufferAssignment where -// different instructions may happen to get the same allocation. -class HloBuffer { - public: - using Id = int64; - - HloBuffer(int64 id) : id_(id) {} - - // Return the unique identifier for this HloBuffer. - int64 id() const { return id_; } - - // Add a value to the set of values held by this buffer. Also adds the - // HloLocations of the value to the locations vector of the buffer. If the - // buffer already contains this value, then this method is a nop. - void AddValue(const HloValue& value); - - // Return the IDs of all values contained in this buffer. - const std::vector& value_ids() const { return value_ids_; } - - // Return the locations (output of which instruction and at what index) where - // the buffer is used. This is exactly the union of the locations of the - // HloValues contained by the buffer. 
- const std::vector& locations() const { return locations_; } - - string ToString() const; - - bool operator==(const HloBuffer& other) const; - bool operator!=(const HloBuffer& other) const { return !(*this == other); } - - private: - // Unique identifier for this HloBuffer. - const Id id_; - - // The set of values contained in the this buffer. - std::vector value_ids_; - - // The set of locations where this buffer is used. - std::vector locations_; -}; - -std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); - -// A class representing the set of possible HloBuffers at a particular -// HloLocation (shape index in the output of an instruction) in the XLA -// graph. In most cases, the buffer set will have a single HloBuffer indicating -// that the HloBuffer which appears at that particular location is known -// unambiguously at compile-time. However, tuple-shaped Select instructions can -// introduce ambiguity as the tuple elements of the operands are passed by -// reference into the output of the Select. For example: -// -// %pred = ... -// %tuple0 = Tuple(%a, %b) -// %tuple1 = Tuple(%x, %y) -// %select = Select(%pred, %tuple0, %tuple1) -// -// In this case the HloBufferSet at HloLocation{%select, {0}} contains the -// HloBuffer holding %a and the HloBuffer holding %x. -class HloBufferSet { - public: - HloBufferSet() = default; - - // Add the given buffer to this buffer set. If the buffer already exists in - // the set, then this is a NOP. - void AddBuffer(HloBuffer::Id buffer_id); - - // Removes the given buffer from this buffer set. CHECK fails in the buffer is - // not contained in this set. - void RemoveBufferOrDie(HloBuffer::Id buffer_id); - - // Returns the unique buffer in this set. CHECK fails if the set does not - // contain exactly one buffer. - HloBuffer::Id GetUniqueBufferId() const { - CHECK_EQ(buffer_ids().size(), 1); - return buffer_ids()[0]; - } - - // Returns the IDs of the HloBuffers contained in this buffer set. 
- const std::vector& buffer_ids() const { return buffer_ids_; } - - string ToString() const; - - private: - // The IDs of the HloBuffers containted in this buffer set. - std::vector buffer_ids_; -}; - -std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set); - -// A class collecting the HloBuffers in the output of an HLO instruction. For -// array-shaped instructions, an InstructionBufferSet trivially holds a single -// HloBufferSet. Tuple-shaped InstructionBufferSets hold multiple -// HloBufferSets. -class InstructionBufferSet : public ShapeTree { - public: - InstructionBufferSet(const Shape& shape) : ShapeTree(shape) {} - - // Returns true if any HloBufferSet contained in this InstructionBufferSet - // is not a singleton. - bool IsAmbiguous() const; - - // Returns true if any HloBuffer appears in more than one HloBufferSet - // contained in this InstructionBufferSet. - bool IsDistinct() const; - - string ToString() const; -}; - -std::ostream& operator<<(std::ostream& out, - const InstructionBufferSet& buffer_set); - class HloAliasAnalysis { public: static StatusOr> Run(HloModule* module); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 24c467d411b93be32bd884a8bb92ef288d9c2f10..d67b48dff0c6b306ff1a3358da7fa13597a3ffd1 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -37,12 +37,12 @@ using ::testing::UnorderedElementsAre; class HloAliasAnalysisTest : public HloTestBase { protected: - HloAliasAnalysisTest() : module_(TestName()) {} + HloAliasAnalysisTest() : module_(CreateNewModule()) {} // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
const HloAliasAnalysis& RunAnalysis() { - analysis_ = HloAliasAnalysis::Run(&module_).ConsumeValueOrDie(); + analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie(); return *analysis_; } @@ -77,7 +77,31 @@ class HloAliasAnalysisTest : public HloTestBase { return analysis_->dataflow_analysis().GetValue(buffer.value_ids()[0]); } - HloModule module_; + // Returns true if any values held in the same buffer interfere. Generally, in + // the compiler pipeline copy-insertion will guarantee that this interference + // never occurs, but HLO graphs with interference can be explicitly + // constructed. + bool AnyValuesInSameBufferInterfere() { + DependencyHloOrdering ordering(module_.get()); + for (const HloBuffer* buffer : analysis_->buffers()) { + for (HloValue::Id value_id_a : buffer->value_ids()) { + for (HloValue::Id value_id_b : buffer->value_ids()) { + if (value_id_a != value_id_b && + analysis_->dataflow_analysis().MayInterfere( + value_id_a, value_id_b, ordering)) { + VLOG(1) << analysis_->dataflow_analysis().GetValue(value_id_a) + << " interferes with " + << analysis_->dataflow_analysis().GetValue(value_id_b) + << " in buffer: " << *buffer; + return true; + } + } + } + } + return false; + } + + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -87,12 +111,12 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { // Test the analysis on a single binary operation (Add). 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -107,6 +131,8 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { EXPECT_FALSE(analysis.GetInstructionBufferSet(add).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(add).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, TupleAndGtes) { @@ -124,7 +150,7 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -156,6 +182,8 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, NondistinctTuple) { @@ -168,7 +196,7 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { // param0 is included twice in the tuple. 
auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({param0, param1, param0})); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -179,6 +207,8 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SingleCall) { @@ -192,16 +222,16 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -217,6 +247,8 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { EXPECT_THAT( analysis.GetUniqueBufferAt(add).locations(), UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call, {}})); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { @@ -229,18 +261,18 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, 
subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -269,6 +301,8 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { EXPECT_FALSE(analysis.GetInstructionBufferSet(subparam1).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam1).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SingleWhile) { @@ -303,27 +337,27 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); auto body_tuple = body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); // Condition computation trivially returns a constant "false". 
auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -356,6 +390,8 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { GetValueDefinedAt(body_param, {1}), GetValueDefinedAt(cond_param, {1}), GetValueDefinedAt(add))); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SequentialWhiles) { @@ -392,21 +428,21 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); 
cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -415,7 +451,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); auto xla_while2 = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -449,13 +485,21 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - auto cond_builder = HloComputation::Builder("condition"); - cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); - HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + auto build_cond_computation = [&tuple_shape]() { + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + return 
cond_builder.Build(); + }; + // Build separate condition computations so the call graph is flat. The + // callgraph is always flattened in the compiler pipeline, and the flattened + // callgraph enables representative interference analysis. + HloComputation* condition1 = + module_->AddEmbeddedComputation(build_cond_computation()); + HloComputation* condition2 = + module_->AddEmbeddedComputation(build_cond_computation()); // Element 0 passes transparently through the body. auto inner_builder = HloComputation::Builder("inner_body"); @@ -470,7 +514,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { inner_builder.AddInstruction( HloInstruction::CreateTuple({inner_element_0, add})); HloComputation* inner_body = - module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); // Element 1 passes transparently through the body. auto outer_builder = HloComputation::Builder("outer_body"); @@ -485,20 +529,20 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { auto outer_tuple = outer_builder.AddInstruction( HloInstruction::CreateTuple({negate, outer_element_1})); auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, condition, inner_body, outer_tuple)); + tuple_shape, condition1, inner_body, outer_tuple)); HloComputation* outer_body = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( - HloInstruction::CreateWhile(tuple_shape, 
condition, outer_body, tuple)); - module_.AddEntryComputation(builder.Build()); + HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple)); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -515,6 +559,8 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { analysis.GetUniqueBufferAt(nested_while, /*index=*/{1})); EXPECT_EQ(analysis.GetUniqueBufferAt(constant2), analysis.GetUniqueBufferAt(inner_element_1)); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { @@ -548,28 +594,28 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 2)); body_builder.AddInstruction(HloInstruction::CreateTuple( {body_element_1, body_element_2, body_element_0})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = 
builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2, constant3})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -593,6 +639,10 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { analysis.GetUniqueBufferAt(constant2)); EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), analysis.GetUniqueBufferAt(constant3)); + + // All elements of the loop state tuple are forced into the same buffer + // resulting in liveness interference. + EXPECT_TRUE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, TupleSelect) { @@ -600,15 +650,15 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { // instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -627,7 +677,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, select12, select34)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -655,6 +705,8 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { @@ -688,22 +740,22 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kNegate, body_element)); body_builder.AddInstruction(HloInstruction::CreateTuple({negate})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -713,7 
+765,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, select)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -736,17 +788,21 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { EXPECT_TRUE(analysis.GetInstructionBufferSet(select).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(xla_while).IsDistinct()); + + // The two operands of the select get flattened into the same buffer resulting + // in liveness interference. + EXPECT_TRUE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, Bitcast) { // Bitcasting a value should not produce a new buffer. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc new file mode 100644 index 0000000000000000000000000000000000000000..f289f56966ac2e9350ff32e51b1081ed94554d96 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_buffer.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +void HloBuffer::AddValue(const HloValue& value) { + // If the value is already contained in this buffer, just return. + if (std::find(value_ids_.begin(), value_ids_.end(), value.id()) != + value_ids_.end()) { + return; + } + + value_ids_.push_back(value.id()); + + // Add all of the locations of the HloValue to this buffer. + for (const HloLocation& location : value.locations()) { + if (std::find(locations_.begin(), locations_.end(), location) == + locations_.end()) { + locations_.push_back(location); + } + } +} + +bool HloBuffer::operator==(const HloBuffer& other) const { + bool equal = id() == other.id(); + if (equal) { + // DCHECK because these comparisons are expensive (linear time). 
+ DCHECK(value_ids() == other.value_ids()); + DCHECK(locations() == other.locations()); + } + return equal; +} + +string HloBuffer::ToString() const { + return StrCat("HloBuffer ", id_, ", values: ", Join(value_ids_, ", ")); +} + +std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { + out << buffer.ToString(); + return out; +} + +void HloBufferSet::AddBuffer(HloBuffer::Id buffer_id) { + if (std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id) == + buffer_ids_.end()) { + buffer_ids_.push_back(buffer_id); + } +} + +void HloBufferSet::RemoveBufferOrDie(HloBuffer::Id buffer_id) { + auto it = std::find(buffer_ids_.begin(), buffer_ids_.end(), buffer_id); + CHECK(it != buffer_ids_.end()); + buffer_ids_.erase(it); +} + +string HloBufferSet::ToString() const { + return StrCat("HloBufferSet, buffers: ", Join(buffer_ids_, ", ")); +} + +std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set) { + out << buffer_set.ToString(); + return out; +} + +bool InstructionBufferSet::IsAmbiguous() const { + bool is_ambiguous = false; + ForEachElement( + [&is_ambiguous](const ShapeIndex& index, const HloBufferSet& buffer_set) { + is_ambiguous |= buffer_set.buffer_ids().size() > 1; + }); + return is_ambiguous; +} + +bool InstructionBufferSet::IsDistinct() const { + bool is_distinct = true; + tensorflow::gtl::FlatSet seen_ids; + ForEachElement([&is_distinct, &seen_ids](const ShapeIndex& index, + const HloBufferSet& buffer_set) { + for (HloBuffer::Id buffer_id : buffer_set.buffer_ids()) { + auto pair = seen_ids.insert(buffer_id); + if (!pair.second) { + is_distinct = false; + } + } + }); + return is_distinct; +} + +string InstructionBufferSet::ToString() const { + string out = + StrCat("InstructionBufferSet(", ShapeUtil::HumanString(shape()), ")\n"); + ForEachElement([this, &out](const ShapeIndex& index, + const HloBufferSet& value_set) { + StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); + }); + return out; +} + 
+std::ostream& operator<<(std::ostream& out, + const InstructionBufferSet& buffer_set) { + out << buffer_set.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h new file mode 100644 index 0000000000000000000000000000000000000000..e38499d210bcc59730a7ea2bcb09958bec9942af --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// A container which can hold one or more HloValues. An HLO buffer abstractly +// represents the allocation which HLO instructions write into and read +// from. Generally there is a one-to-one correspondence between HloBuffers and +// HloValue where each HloValue in the module is held in a unique HloBuffer. An +// exception is the while instruction which updates the loop state in-place. 
In +// this case, we have a single HloBuffer for each HloLocation in the loop state, +// but multiple HloValues. For example: +// +// %init = ... +// %while = While(%init, body, condition) +// +// body: +// %body_param = Param(0) +// ... +// %body_root = ... +// +// condition: +// %cond_param = Param(0) +// ... +// +// For simplicity, assume that %while is array-shaped. In this case, we have a +// single HloBuffer which holds the following HloValues: HloValue{%init}, +// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and +// HloValue{%cond_param}. +// +// HloBuffers may appear at different HloLocations in the module mirroring the +// same property of HloValues. For example: +// +// %sub = Sub(...) +// %add = Add(...) +// %tuple = Tuple(%add, %sub) +// %gte = GetTupleElement(%tuple, 0) +// +// In this case, the HloBuffer containing %add appears at the following +// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and +// HloLocation{%gte, {}}. +// +// Different HloLocations which share the same HloBuffer indicate mandatory +// aliasing in the HLO module. These locations must share the same memory +// allocation for correctness (the backends rely on this property). This differs +// from incidental aliasing introduced by memory reuse in BufferAssignment where +// different instructions may happen to get the same allocation. +class HloBuffer { + public: + using Id = int64; + + HloBuffer(Id id) : id_(id) {} + + // Return the unique identifier for this HloBuffer. + Id id() const { return id_; } + + // Add a value to the set of values held by this buffer. Also adds the + // HloLocations of the value to the locations vector of the buffer. If the + // buffer already contains this value, then this method is a nop. + void AddValue(const HloValue& value); + + // Return the IDs of all values contained in this buffer. 
+ const std::vector& value_ids() const { return value_ids_; } + + // Return the locations (output of which instruction and at what index) where + // the buffer is used. This is exactly the union of the locations of the + // HloValues contained by the buffer. + const std::vector& locations() const { return locations_; } + + string ToString() const; + + bool operator==(const HloBuffer& other) const; + bool operator!=(const HloBuffer& other) const { return !(*this == other); } + + private: + // Unique identifier for this HloBuffer. + const Id id_; + + // The set of values contained in this buffer. + std::vector value_ids_; + + // The set of locations where this buffer is used. + std::vector locations_; +}; + +std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); + +// A class representing the set of possible HloBuffers at a particular +// HloLocation (shape index in the output of an instruction) in the XLA +// graph. In most cases, the buffer set will have a single HloBuffer indicating +// that the HloBuffer which appears at that particular location is known +// unambiguously at compile-time. However, tuple-shaped Select instructions can +// introduce ambiguity as the tuple elements of the operands are passed by +// reference into the output of the Select. For example: +// +// %pred = ... +// %tuple0 = Tuple(%a, %b) +// %tuple1 = Tuple(%x, %y) +// %select = Select(%pred, %tuple0, %tuple1) +// +// In this case the HloBufferSet at HloLocation{%select, {0}} contains the +// HloBuffer holding %a and the HloBuffer holding %x. +class HloBufferSet { + public: + HloBufferSet() = default; + + // Add the given buffer to this buffer set. If the buffer already exists in + // the set, then this is a NOP. + void AddBuffer(HloBuffer::Id buffer_id); + + // Removes the given buffer from this buffer set. CHECK fails if the buffer is + // not contained in this set. + void RemoveBufferOrDie(HloBuffer::Id buffer_id); + + // Returns the unique buffer in this set. 
CHECK fails if the set does not + // contain exactly one buffer. + HloBuffer::Id GetUniqueBufferId() const { + CHECK_EQ(buffer_ids().size(), 1); + return buffer_ids()[0]; + } + + // Returns the IDs of the HloBuffers contained in this buffer set. + const std::vector& buffer_ids() const { return buffer_ids_; } + + string ToString() const; + + private: + // The IDs of the HloBuffers contained in this buffer set. + std::vector buffer_ids_; +}; + +std::ostream& operator<<(std::ostream& out, const HloBufferSet& buffer_set); + +// A class collecting the HloBuffers in the output of an HLO instruction. For +// array-shaped instructions, an InstructionBufferSet trivially holds a single +// HloBufferSet. Tuple-shaped InstructionBufferSets hold multiple +// HloBufferSets. +class InstructionBufferSet : public ShapeTree { + public: + InstructionBufferSet(const Shape& shape) : ShapeTree(shape) {} + + // Returns true if any HloBufferSet contained in this InstructionBufferSet + // is not a singleton. + bool IsAmbiguous() const; + + // Returns true if no HloBuffer appears in more than one HloBufferSet + // contained in this InstructionBufferSet. + bool IsDistinct() const; + + string ToString() const; +}; + +std::ostream& operator<<(std::ostream& out, + const InstructionBufferSet& buffer_set); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_BUFFER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index ff76cc7bf67e29d489f9b32e4fce94ce28b59992..6a5533c4696b83b9468617772a71769367f9481a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -66,22 +67,25 @@ HloComputation::HloComputation( HloInstruction* root_instruction, bool is_fusion_computation) : name_(name), root_instruction_(root_instruction), - is_fusion_computation_(is_fusion_computation), - instruction_name_uniquer_(/*separator=*/".") { + is_fusion_computation_(is_fusion_computation) { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; for (auto& instruction : *instructions) { if (instruction->opcode() == HloOpcode::kParameter) { int64 param_no = instruction->parameter_number(); - CHECK_GE(param_no, 0); - CHECK_LT(param_no, param_instructions_.size()); - CHECK_EQ(nullptr, param_instructions_[param_no]); + CHECK(param_no >= 0 && param_no < parameter_count) + << "\nERROR: invalid parameter number. Expected [0, " + << parameter_count << "), got " << param_no; + CHECK(param_instructions_[param_no] == nullptr) + << "\nERROR: parameter number " << param_no + << " already allocated in this computation"; param_instructions_[param_no] = instruction.get(); } root_found |= instruction.get() == root_instruction_; AddInstructionInternal(std::move(instruction)); } - CHECK(root_found); + CHECK(root_found) + << "\nERROR: root instruction is not present in computation."; } HloInstruction* HloComputation::AddInstruction( @@ -94,8 +98,9 @@ HloInstruction* HloComputation::AddInstruction( HloInstruction* HloComputation::AddInstructionInternal( std::unique_ptr instruction) { - // Generate a unique name for the instruction. 
- instruction->UniquifyName(&instruction_name_uniquer_); + if (parent() != nullptr) { + instruction->UniquifyName(&parent()->instruction_name_uniquer()); + } Reparent(instruction.get()); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = @@ -537,67 +542,46 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, return RemoveInstructionAndUnusedOperands(old_instruction); } -HloComputation::ReachabilityMap::ReachabilityMap( - const std::list& all_instructions) { - const int n = all_instructions.size(); - int next_id = 0; - for (const auto* hlo : all_instructions) { - ids_[hlo] = next_id; - next_id++; - } - DCHECK_EQ(n, ids_.size()); // instructions should be unique - matrix_.Reset(n * n); -} - -void HloComputation::ReachabilityMap::SetReachable(const HloInstruction* a, - const HloInstruction* b) { - const int id_a = FindOrDie(ids_, a); - const int id_b = FindOrDie(ids_, b); - matrix_.set(id_a * ids_.size() + id_b); -} +std::unique_ptr HloComputation::ComputeReachability() + const { + const std::list all = MakeInstructionPostOrder(); + auto result = MakeUnique(all); -bool HloComputation::ReachabilityMap::IsReachable( - const HloInstruction* a, const HloInstruction* b) const { - const int id_a = FindOrDie(ids_, a); - const int id_b = FindOrDie(ids_, b); - return matrix_.get(id_a * ids_.size() + id_b); + std::vector inputs; + for (const HloInstruction* hlo : all) { + inputs.assign(hlo->operands().begin(), hlo->operands().end()); + inputs.insert(inputs.end(), hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + result->SetReachabilityToUnion(inputs, hlo); + } + return result; } -bool HloComputation::ReachabilityMap::IsConnected( - const HloInstruction* a, const HloInstruction* b) const { - const int id_a = FindOrDie(ids_, a); - const int id_b = FindOrDie(ids_, b); - return matrix_.get(id_a * ids_.size() + id_b) || - matrix_.get(id_b * ids_.size() + id_a); -} +void 
HloComputation::UpdateReachabilityThroughInstruction( + const HloInstruction* instruction, HloReachabilityMap* reachability_map) { + std::queue worklist; + worklist.push(instruction); -void HloComputation::ReachabilityMap::SetReachableAndTransitiveClosure( - const HloInstruction* a, const HloInstruction* b) { - const int id_a = FindOrDie(ids_, a); - const int id_b = FindOrDie(ids_, b); - const int n = ids_.size(); - matrix_.set(id_a * n + id_b); + std::vector inputs; - // Copy transitive set for b into entries for a - for (int i = 0; i < n; i++) { - if (matrix_.get(id_b * n + i)) { - matrix_.set(id_a * n + i); - } - } -} + while (!worklist.empty()) { + const HloInstruction* item = worklist.front(); + worklist.pop(); -std::unique_ptr -HloComputation::ComputeTransitiveOperands() const { - const auto all = MakeInstructionPostOrder(); - auto result = MakeUnique(all); + inputs.assign(item->operands().begin(), item->operands().end()); + inputs.insert(inputs.end(), item->control_predecessors().begin(), + item->control_predecessors().end()); - // Fill in the dependency bit matrix - for (const auto* hlo : all) { - for (const HloInstruction* operand : hlo->operands()) { - result->SetReachableAndTransitiveClosure(hlo, operand); + if (reachability_map->SetReachabilityToUnion(inputs, item)) { + // Add immediate successors to worklist. + for (const HloInstruction* user : item->users()) { + worklist.push(user); + } + for (const HloInstruction* succ : item->control_successors()) { + worklist.push(succ); + } } } - return result; } std::vector HloComputation::CollectUnreachableRoots() const { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 39074b24e41f073b6b5b60880cbd1f6e2e9b399d..cf6df3c94f885816d20530161822f7cc948a30be 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -29,11 +29,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -153,9 +153,18 @@ class HloComputation { // this order, definitions of values always appear before their uses. std::list MakeInstructionPostOrder() const; - // Computes and returns the mapping from HLO to its transitive operands. - class ReachabilityMap; - std::unique_ptr ComputeTransitiveOperands() const; + // Computes and returns the reachability between HLO instructions in the + // computation. The returned HloReachabilityMap is constructed such that + // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a + // directed path (from producer to consumer) from 'a' to 'b'. Both data + // dependencies (operands) and control dependencies are considered for + // reachability. Trivially an instruction is reachable from itself. + std::unique_ptr ComputeReachability() const; + + // Updates the given reachability map after the immediate predecessor set + // (operands and control predecessors) of 'instruction' has changed. 
+ void UpdateReachabilityThroughInstruction( + const HloInstruction* instruction, HloReachabilityMap* reachability_map); int64 instruction_count() const { return instructions_.size(); } @@ -308,34 +317,6 @@ class HloComputation { TF_DISALLOW_COPY_AND_ASSIGN(HloComputation); }; -class HloComputation::ReachabilityMap { - public: - // Sets up an empty reachable matrix for the full set of - // instructions specified in "all_instructions" - explicit ReachabilityMap(const std::list& all_instructions); - // Sets entry so that IsReachable(a, b) will return true - void SetReachable(const HloInstruction* a, const HloInstruction* b); - - // Sets IsReachable(a_inst, b_inst) as well as IsReachable(a_inst, trans) - // for all "trans" s.t. "IsReachable(b_inst, trans)" is true - void SetReachableAndTransitiveClosure(const HloInstruction* a_inst, - const HloInstruction* b_inst); - - // Returns true if "b" is reachable from "a" - bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; - - // Returns true if "b" is reachable from "a" or "a" is reachable from "b" - bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; - - private: - friend class HloComputation; - - // dense id assignment from HloInstruction* to number - tensorflow::gtl::FlatMap ids_; - // matrix_(a,b) is true iff b is reachable from a - tensorflow::core::Bitmap matrix_; -}; - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 5d49c83e2d070cb9e5409a62983940225b903b2b..4a4a8556692b3da6f92f8333397a9537ade2f8ef 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -110,7 +110,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { // Test GetInstructionPostOrder for a computation with one instruction. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto computation = builder.Build(); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); @@ -121,7 +121,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { // instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( @@ -136,7 +136,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { // Test GetInstructionPostOrder for a computation with a trace instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto trace = @@ -155,13 +155,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { // which are not connected. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto computation = builder.Build(); EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -173,11 +173,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -197,11 +197,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { // computation has multiple roots (dead code). 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); // Add three disconnected add expressions. builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -248,7 +248,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { // Test that DeepCopyInstruction properly copies an array. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto computation = builder.Build(); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -260,9 +260,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { // Test that DeepCopyInstruction properly copies a tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -280,7 +280,7 @@ TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( @@ -303,7 +303,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { // twice. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto dead_negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -326,9 +326,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { TEST_F(HloComputationTest, CloneWithControlDependency) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -352,6 +352,105 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } +TEST_F(HloComputationTest, Reachability) { + // Test reachability of a non-trivial computation: + // + // const1 const2 + // | | + // | +-------+ + // | | | + // add .. negate + // | . | + // | .... exp + // | | + // +---+ +-+---+ + // | | | + // multiply copy + // + // There is a control dependency from 'add' to 'exp'. 
+ auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); + + auto computation = builder.Build(/*root_instruction=*/mul); + + TF_CHECK_OK(add->AddControlDependencyTo(exp)); + auto reachability = computation->ComputeReachability(); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_TRUE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_TRUE(reachability->IsReachable(constant1, copy)); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_TRUE(reachability->IsReachable(constant2, negate)); + EXPECT_TRUE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_TRUE(reachability->IsReachable(constant2, copy)); + + EXPECT_FALSE(reachability->IsReachable(exp, constant1)); + EXPECT_FALSE(reachability->IsReachable(exp, constant2)); + EXPECT_FALSE(reachability->IsReachable(exp, add)); + 
EXPECT_FALSE(reachability->IsReachable(exp, negate)); + EXPECT_TRUE(reachability->IsReachable(exp, exp)); + EXPECT_TRUE(reachability->IsReachable(exp, mul)); + EXPECT_TRUE(reachability->IsReachable(exp, copy)); + + EXPECT_FALSE(reachability->IsReachable(mul, constant1)); + EXPECT_FALSE(reachability->IsReachable(mul, constant2)); + EXPECT_FALSE(reachability->IsReachable(mul, add)); + EXPECT_FALSE(reachability->IsReachable(mul, negate)); + EXPECT_FALSE(reachability->IsReachable(mul, exp)); + EXPECT_TRUE(reachability->IsReachable(mul, mul)); + EXPECT_FALSE(reachability->IsReachable(mul, copy)); + + EXPECT_TRUE(reachability->IsConnected(constant1, copy)); + EXPECT_TRUE(reachability->IsConnected(copy, constant1)); + EXPECT_FALSE(reachability->IsConnected(negate, add)); + EXPECT_FALSE(reachability->IsConnected(add, negate)); + + // Remove the control dependency then update and verify the reachability map + ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); + computation->UpdateReachabilityThroughInstruction(exp, reachability.get()); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_FALSE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_FALSE(reachability->IsReachable(constant1, copy)); + + // Change a use within the graph then update and verify the reachability map + ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); + computation->UpdateReachabilityThroughInstruction(negate, reachability.get()); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_FALSE(reachability->IsReachable(constant2, negate)); + EXPECT_FALSE(reachability->IsReachable(constant2, 
exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_FALSE(reachability->IsReachable(constant2, copy)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 93f448e701853e271646c9f8fb0d42f49489b756..804efdd906a176af38c2a8c1e93a849f71307ddf 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -58,6 +58,13 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Broadcasts dramatically increase the size of constants which is often + // detrimental to performance and memory capacity so do not fold + // broadcasts. + if (instruction->opcode() == HloOpcode::kBroadcast) { + continue; + } + std::unique_ptr result = evaluator->TryEvaluate(instruction); // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. 
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 31b81052cb2b00e602b94b9d84525a623caa741e..1c60b06dddc8cf1f59bb1f8cb39b7d4d16019ba9 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -41,7 +41,7 @@ using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); @@ -55,15 +55,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), + EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42); } TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -77,15 +76,14 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), + EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42.0f); } TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( - 
LiteralUtil::CreateR1({42.0f, 19.0f}))); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({42.0f, 19.0f}))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); @@ -99,12 +97,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {0}), - 42); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {1}), - 19); + EXPECT_EQ(computation->root_instruction()->literal().Get({0}), 42); + EXPECT_EQ(computation->root_instruction()->literal().Get({1}), 19); } TEST_F(HloConstantFoldingTest, Concatenate) { @@ -126,7 +120,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { for (auto csize : test_config.concat_sizes) { dimensions[test_config.concat_dimension] = csize; concat_size += csize; - auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); + auto literal = Literal::CreateFromDimensions(F32, dimensions); HloInstruction* insn = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); operands.push_back(insn); @@ -180,7 +174,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSIGN_OR_ASSERT_OK(auto literal, LiteralTestUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = LiteralUtil::CloneToUnique(*literal); + auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -200,12 +194,10 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; bool matched = true; - LiteralUtil::EachCell( - root->literal(), + root->literal().EachCell( [&](tensorflow::gtl::ArraySlice 
indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - matched = matched && (value == LiteralUtil::Get(*literal_clone, - rindexes)); + matched = matched && (value == literal_clone->Get(rindexes)); }); EXPECT_TRUE(matched); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 38cc74b0f1e640d4e72188416258d9b262053152..522dddea4e8935a11dfdedbfa4d911cf5b0b124f 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -25,34 +25,56 @@ limitations under the License. namespace xla { +constexpr char HloCostAnalysis::kFlopsKey[]; +constexpr char HloCostAnalysis::kTranscendentalsKey[]; +constexpr char HloCostAnalysis::kBytesAccessedKey[]; +constexpr char HloCostAnalysis::kSecondsKey[]; + +HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size) + : HloCostAnalysis(shape_size, {}) {} + +HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size, + const Properties& per_second_rates) + : shape_size_(shape_size), per_second_rates_(per_second_rates) {} + Status HloCostAnalysis::Preprocess(HloInstruction* hlo) { // Set current instruction cost values to reasonable default values. Each - // handler can overwrite these values. In Postprocess, these value are + // handler can overwrite these values. In Postprocess, these values are // accumulated and written to the per-instruction maps. - current_flop_count_ = 0; - current_transcendental_count_ = 0; + current_properties_.clear(); + current_should_compute_bottleneck_time_ = true; - // The default element count for an instruction is the sum of elements in the - // operands and output. The default ShapeUtil::ByteSizeOf does not handle - // opaque types. - current_bytes_accessed_ = shape_size_(hlo->shape()); + // The default number of bytes accessed for an instruction is the sum of the + // sizes of the inputs and outputs. 
The default ShapeUtil::ByteSizeOf does not + // handle opaque types. + float bytes_accessed = shape_size_(hlo->shape()); for (const HloInstruction* operand : hlo->operands()) { - current_bytes_accessed_ += shape_size_(operand->shape()); + bytes_accessed += shape_size_(operand->shape()); } + current_properties_[kBytesAccessedKey] = bytes_accessed; return Status::OK(); } Status HloCostAnalysis::Postprocess(HloInstruction* hlo) { - // Accumulate cost values and write into per-instruction maps. - flop_count_ += current_flop_count_; - hlo_to_flop_count_[hlo] = current_flop_count_; - - transcendental_count_ += current_transcendental_count_; - hlo_to_transcendental_count_[hlo] = current_transcendental_count_; + if (current_should_compute_bottleneck_time_) { + // Compute the time as the time of the bottleneck, i.e. the slowest property + // given the per-second rate of each property. + float max_seconds = 0.0f; + for (const auto& property : current_properties_) { + if (property.first != kSecondsKey) { + max_seconds = std::max( + max_seconds, + property.second / GetProperty(property.first, per_second_rates_)); + } + } + current_properties_[kSecondsKey] = max_seconds; + } - bytes_accessed_ += current_bytes_accessed_; - hlo_to_bytes_accessed_[hlo] = current_bytes_accessed_; + TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second); + for (const auto& property : current_properties_) { + properties_sum_[property.first] += property.second; + } return Status::OK(); } @@ -65,25 +87,39 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { auto opcode = hlo_instruction->opcode(); // We treat the two opcodes (kExp, kPower) as transcendental operations. if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower) { - current_transcendental_count_ = computation_count; + current_properties_[kTranscendentalsKey] = computation_count; } else { // Note: transcendental operations are considered a separate category from // FLOPs. 
- current_flop_count_ = computation_count; + current_properties_[kFlopsKey] = computation_count; } return Status::OK(); } +/*static*/ float HloCostAnalysis::GetProperty(const string& key, + const Properties& properties) { + auto key_value = properties.find(key); + return key_value == properties.end() ? 0.0f : key_value->second; +} + +/*static*/ float HloCostAnalysis::GetPropertyForHlo( + const HloInstruction& hlo, const string& key, + const HloToProperties& hlo_to_properties) { + auto it = hlo_to_properties.find(&hlo); + if (it == hlo_to_properties.end()) { + return 0.0f; + } else { + return GetProperty(key, it->second); + } +} + Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* operand) { + HloOpcode opcode) { return HandleElementwiseOp(hlo); } Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { + HloOpcode opcode) { return HandleElementwiseOp(hlo); } @@ -100,14 +136,18 @@ Status HloCostAnalysis::HandleClamp(HloInstruction* clamp, return HandleElementwiseOp(clamp); } +Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo) { + return HandleElementwiseOp(hlo); +} + Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) { - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(HloInstruction* constant, const Literal& literal) { - current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -115,7 +155,7 @@ Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) { // GetTupleElement forwards a pointer and does not touch each element in the // output. 
- current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -153,8 +193,9 @@ Status HloCostAnalysis::HandleTuple( tensorflow::gtl::ArraySlice operands) { // The tuple instruction only gathers pointers from inputs (it doesn't iterate // through them). The memory touched is then only the size of the output - // buffer. - current_bytes_accessed_ = shape_size_(tuple->shape()); + // index table of the tuple. + + current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape()); return Status::OK(); } @@ -164,13 +205,11 @@ Status HloCostAnalysis::HandleConcatenate( return Status::OK(); } -Status HloCostAnalysis::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { +Status HloCostAnalysis::HandleConvert(HloInstruction* convert) { return HandleElementwiseOp(convert); } -Status HloCostAnalysis::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status HloCostAnalysis::HandleCopy(HloInstruction* copy) { return Status::OK(); } @@ -194,7 +233,7 @@ Status HloCostAnalysis::HandleDot(HloInstruction* dot, } // We count an FMA operation as 2 floating point operations. - current_flop_count_ = kFmaFlops * fma_count; + current_properties_[kFlopsKey] = kFmaFlops * fma_count; return Status::OK(); } @@ -210,16 +249,17 @@ Status HloCostAnalysis::HandleMap( HloInstruction* map, tensorflow::gtl::ArraySlice operands, HloComputation* function, tensorflow::gtl::ArraySlice /*static_operands*/) { - // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + // Compute properties of the mapped function. + TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this Map operation. 
- int64 element_count = ShapeUtil::ElementsIn(map->shape()); - current_transcendental_count_ = - element_count * visitor.transcendental_count(); - current_flop_count_ = element_count * visitor.flop_count(); + const int64 element_count = ShapeUtil::ElementsIn(map->shape()); + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * element_count; + } + } return Status::OK(); } @@ -227,16 +267,17 @@ Status HloCostAnalysis::HandleReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(reduce->shape()); - current_flop_count_ = reduction_count * visitor.flop_count(); - current_transcendental_count_ = - reduction_count * visitor.transcendental_count(); + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * reduction_count; + } + } return Status::OK(); } @@ -244,55 +285,63 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, HloInstruction* operand, const Window& window, HloComputation* function) { - // Compute the cost of the user function. - HloInstruction* function_instruction = function->root_instruction(); - HloCostAnalysis visitor(shape_size_); - TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor)); + // Compute the properties of the reduction function. 
+ TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessSubcomputation(function)); // Compute the cost of all elements for this ReduceWindow operation. For each - // output element, (window_size - 1) number of user computations are applied. - auto output_size = ShapeUtil::ElementsIn(reduce_window->shape()); - int64 window_size = 1; + // output element there are window_size - 1 reductions to perform. + int64 window_element_count = 1; for (const auto& dimension : window.dimensions()) { - window_size *= dimension.size(); + window_element_count *= dimension.size(); + } + const int64 output_element_count = + ShapeUtil::ElementsIn(reduce_window->shape()); + const int64 reduction_count = + (window_element_count - 1) * output_element_count; + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * reduction_count; + } } - current_flop_count_ = output_size * (window_size - 1) * visitor.flop_count(); - current_transcendental_count_ = - output_size * (window_size - 1) * visitor.transcendental_count(); return Status::OK(); } Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { - // Compute the cost of the select and scatter function. - HloInstruction* select = instruction->select()->root_instruction(); - HloCostAnalysis select_visitor(shape_size_); - TF_RETURN_IF_ERROR(select->Accept(&select_visitor)); - HloInstruction* scatter = instruction->scatter()->root_instruction(); - HloCostAnalysis scatter_visitor(shape_size_); - TF_RETURN_IF_ERROR(scatter->Accept(&scatter_visitor)); + // Compute the properties of the select and scatter function. + // Compute the properties of the reduction function. 
+ TF_ASSIGN_OR_RETURN(const Properties select_properties, + ProcessSubcomputation(instruction->select())); + TF_ASSIGN_OR_RETURN(const Properties scatter_properties, + ProcessSubcomputation(instruction->scatter())); // Compute the cost of all elements for this operation. For each scatter - // source element, (window_size - 1) number of select computations and 1 - // scatter computation are applied. + // source element there are window_size - 1 select computations to perform and + // 1 scatter computation to perform. const auto source = instruction->operand(1); const auto source_element_count = ShapeUtil::ElementsIn(source->shape()); - int64 window_size = 1; + int64 window_element_count = 1; for (const auto& dimension : instruction->window().dimensions()) { - window_size *= dimension.size(); + window_element_count *= dimension.size(); + } + const int64 select_count = source_element_count * (window_element_count - 1); + for (const auto& property : select_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] += property.second * select_count; + } + } + for (const auto& property : scatter_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] += + property.second * source_element_count; + } } - current_flop_count_ = - source_element_count * ((window_size - 1) * select_visitor.flop_count() + - scatter_visitor.flop_count()); - current_transcendental_count_ = - source_element_count * - ((window_size - 1) * select_visitor.transcendental_count() + - scatter_visitor.transcendental_count()); return Status::OK(); } Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) { // A bitcast does no computation and touches no memory. 
- current_bytes_accessed_ = 0; + current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } @@ -314,6 +363,12 @@ Status HloCostAnalysis::HandleReshape(HloInstruction* reshape) { return Status::OK(); } +Status HloCostAnalysis::HandleBatchNormTraining( + HloInstruction* batchNormTraining) { + // TODO(b/62294698): Implement cost analysis for batch-norm-learning. + return Status::OK(); +} + Status HloCostAnalysis::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } @@ -326,12 +381,13 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution, const int64 output_features = convolution->shape().dimensions(dnums.feature_dimension()); - // For each output element, we do one fma per element in the - // kernel at some given output feature index. + // For each output element, we do one fma per element in the kernel at some + // given output feature index. const int64 fmas_per_output_element = ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features; const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); - current_flop_count_ = output_elements * fmas_per_output_element * kFmaFlops; + current_properties_[kFlopsKey] = + output_elements * fmas_per_output_element * kFmaFlops; return Status::OK(); } @@ -341,7 +397,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) { // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. - current_flop_count_ = ShapeUtil::ElementsIn(crs->shape()); + current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape()); return Status::OK(); } @@ -350,31 +406,43 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random, // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume // the cost of each RNG is same as a transcendental operation. 
- current_transcendental_count_ = ShapeUtil::ElementsIn(random->shape()); + current_properties_[kTranscendentalsKey] = + ShapeUtil::ElementsIn(random->shape()); return Status::OK(); } Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { - // Compute the cost of the fused expression. - HloInstruction* fused_expression_root = fusion->fused_expression_root(); - // Don't compute sizes inside of fused ops. We don't use the size here and the - // operations inside might not have a layout. - HloCostAnalysis visitor([](const Shape&) { return 0; }); - TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor)); + // Compute the properties of the fused expression and attribute them to the + // fusion node. Use a dummy shape_size to avoid any errors from trying to + // calculate the size of a shape that does not have a layout, since nodes + // inside fusion nodes do not necessarily have a layout assigned. + ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; }; + TF_ASSIGN_OR_RETURN( + current_properties_, + ProcessSubcomputation(fusion->fused_instructions_computation(), + &shape_size)); + + // Fusion nodes that produce a tuple also produce the entries in the tuple. + // Ignore the memory accessed inside fused ops, since fusion is supposed to + // prevent intermediate data from touching slow memory. + current_properties_[kBytesAccessedKey] = 0; + ShapeUtil::ForEachSubshape( + fusion->shape(), + [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { + current_properties_[kBytesAccessedKey] += shape_size_(subshape); + }); + + for (const HloInstruction* operand : fusion->operands()) { + current_properties_[kBytesAccessedKey] += shape_size_(operand->shape()); + } - // Attribute the cost of the fused expression to the fusion node. 
- current_transcendental_count_ = visitor.transcendental_count(); - current_flop_count_ = visitor.flop_count(); return Status::OK(); } Status HloCostAnalysis::HandleCall(HloInstruction* call) { - HloCostAnalysis computation_visitor(shape_size_); - TF_RETURN_IF_ERROR(call->to_apply()->Accept(&computation_visitor)); - - current_flop_count_ = computation_visitor.flop_count(); - current_transcendental_count_ = computation_visitor.transcendental_count(); - current_bytes_accessed_ = computation_visitor.bytes_accessed(); + TF_ASSIGN_OR_RETURN(current_properties_, + ProcessSubcomputation(call->to_apply())); + current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -382,34 +450,38 @@ Status HloCostAnalysis::HandleCustomCall( HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) { - return Unimplemented("custom-call"); + return Unimplemented("Custom-call is not implemented for HLO cost analysis."); } Status HloCostAnalysis::HandleSort(HloInstruction* sort, HloInstruction* operand_instruction) { - // The cost of sort is implementation dependent, so cannot determine at HLO - // level. Assume comparison based N*log(N) sorting. + // This assumes a comparison based N*log(N) algorithm. As for all ops, the + // actual properties of the op depend on the backend implementation. int64 elements = ShapeUtil::ElementsIn(operand_instruction->shape()); - current_flop_count_ = elements * tensorflow::Log2Ceiling(elements); + current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements); return Status::OK(); } Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { - // Since the number of iterations of the while node is not statically - // determined, we cannot precisely compute the cost of a while node. For now - // compute the cost of a single iteration. - // TODO(b/26346211): Improve the cost analysis for while node. 
- HloCostAnalysis body_visitor(shape_size_); - TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&body_visitor)); - HloCostAnalysis condition_visitor(shape_size_); - TF_RETURN_IF_ERROR(xla_while->while_condition()->Accept(&condition_visitor)); + // Since the number of iterations of the while node will not always be + // something that we can statically analyze, we cannot precisely compute the + // cost of a while node. For now compute the cost of a single iteration. + // + // TODO(b/26346211): Improve the cost analysis for while nodes. + TF_ASSIGN_OR_RETURN(const Properties body_properties, + ProcessSubcomputation(xla_while->while_body())); - current_flop_count_ = - body_visitor.flop_count() + condition_visitor.flop_count(); - current_transcendental_count_ = body_visitor.transcendental_count() + - condition_visitor.transcendental_count(); - current_bytes_accessed_ = - body_visitor.bytes_accessed() + condition_visitor.bytes_accessed(); + TF_ASSIGN_OR_RETURN(const Properties condition_properties, + ProcessSubcomputation(xla_while->while_condition())); + + current_properties_.clear(); + for (const auto& property : body_properties) { + current_properties_[property.first] += property.second; + } + for (const auto& property : condition_properties) { + current_properties_[property.first] += property.second; + } + current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -418,19 +490,42 @@ Status HloCostAnalysis::FinishVisit(HloInstruction* root) { return Status::OK(); } +float HloCostAnalysis::flop_count() const { + return GetProperty(kFlopsKey, properties_sum_); +} + +float HloCostAnalysis::transcendental_count() const { + return GetProperty(kTranscendentalsKey, properties_sum_); +} + +float HloCostAnalysis::bytes_accessed() const { + return GetProperty(kBytesAccessedKey, properties_sum_); +} + +float HloCostAnalysis::seconds() const { + return GetProperty(kSecondsKey, properties_sum_); +} + int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) 
const { - auto it = hlo_to_flop_count_.find(&hlo); - return it == hlo_to_flop_count_.end() ? 0 : it->second; + return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_); } int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const { - auto it = hlo_to_transcendental_count_.find(&hlo); - return it == hlo_to_transcendental_count_.end() ? 0 : it->second; + return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_); } int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { - auto it = hlo_to_bytes_accessed_.find(&hlo); - return it == hlo_to_bytes_accessed_.end() ? 0 : it->second; + return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_); +} + +StatusOr HloCostAnalysis::ProcessSubcomputation( + HloComputation* computation, const ShapeSizeFunction* shape_size) { + if (shape_size == nullptr) { + shape_size = &shape_size_; + } + HloCostAnalysis visitor(*shape_size, per_second_rates_); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.properties(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index b2c40f75ca4e833f1f5529977564b0e3a7ca25b1..17c26fc1a15272a53c3abcc9ba2e3a261b1b71ca 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -36,17 +36,22 @@ namespace xla { // operations separately from transcendental operations. class HloCostAnalysis : public DfsHloVisitor { public: + // Each HLO is associated to a vector of properties with the indices given + // below. Sub-classes can add further properties. 
+ typedef std::map Properties; + static constexpr char kFlopsKey[] = "flops"; + static constexpr char kTranscendentalsKey[] = "transcendentals"; + static constexpr char kBytesAccessedKey[] = "bytes accessed"; + static constexpr char kSecondsKey[] = "seconds"; + // shape_size is a function which returns the size in bytes of the top-level // buffer of a shape. using ShapeSizeFunction = std::function; - explicit HloCostAnalysis(const ShapeSizeFunction& shape_size) - : shape_size_(shape_size) {} - - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand) override; - Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) override; + explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); + + Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override; + Status HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, @@ -58,14 +63,14 @@ class HloCostAnalysis : public DfsHloVisitor { HloInstruction* lhs, HloInstruction* rhs) override; Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) override; + Status HandleReducePrecision(HloInstruction* hlo) override; Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) override; Status HandleSend(HloInstruction* send) override; Status HandleRecv(HloInstruction* recv) override; - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleConvert(HloInstruction* convert) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) override; Status 
HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, @@ -83,6 +88,7 @@ class HloCostAnalysis : public DfsHloVisitor { HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function_handle) override; + Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call, @@ -119,48 +125,88 @@ class HloCostAnalysis : public DfsHloVisitor { Status Preprocess(HloInstruction* hlo) override; Status Postprocess(HloInstruction* hlo) override; - // Returns the amount of computations in the graph. - int64 flop_count() const { return flop_count_; } - int64 transcendental_count() const { return transcendental_count_; } + // Set the rates used to calculate the time taken by the computation. These + // need to be set before visiting starts. + void set_flops_per_second(float value) { + per_second_rates_[kFlopsKey] = value; + } + void set_transcendentals_per_second(float value) { + per_second_rates_[kTranscendentalsKey] = value; + } + void set_bytes_per_second(float value) { + per_second_rates_[kBytesAccessedKey] = value; + } + + // Returns properties for the computation. + float flop_count() const; + float transcendental_count() const; + float bytes_accessed() const; + float seconds() const; // Returns the respective cost computed for a particular HLO instruction, or 0 // if the HLO was not found to have a cost in the analysis. int64 flop_count(const HloInstruction& hlo) const; int64 transcendental_count(const HloInstruction& hlo) const; - - // Returns the number of bytes read/written. 
int64 bytes_accessed(const HloInstruction& hlo) const; - int64 bytes_accessed() const { return bytes_accessed_; } + float seconds(const HloInstruction& hlo) const; + + const Properties& properties() const { return properties_sum_; } + const float property(const string& key) const { + return GetProperty(key, properties()); + } + + protected: + typedef std::unordered_map HloToProperties; - private: // An FMA counts as two floating point operations in these analyzes. static constexpr int64 kFmaFlops = 2; + HloCostAnalysis(const ShapeSizeFunction& shape_size, + const Properties& per_second_rates); + + // Returns the properties computed from visiting the computation rooted at the + // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null, + // otherwise uses shape_size_. + StatusOr ProcessSubcomputation( + HloComputation* computation, + const ShapeSizeFunction* shape_size = nullptr); + // Utility function to handle all element-wise operations. Status HandleElementwiseOp(HloInstruction* hlo_instruction); + // Returns 0.0f if the key is not present in the properties. Otherwise, + // returns the value that the key maps to from the properties parameter. + static float GetProperty(const string& key, const Properties& properties); + + // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key + // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that + // the key maps to in the properties of the given hlo. + static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, + const HloToProperties& hlo_to_properties); + // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error. const ShapeSizeFunction shape_size_; - // The total number of floating point operations, transcendental operations, - // and bytes accesses (read or written) in the computation. 
- int64 flop_count_ = 0; - int64 transcendental_count_ = 0; - int64 bytes_accessed_ = 0; - - // Cost counts of the current instruction. These should be set by each - // handlers if different from the default values computed in Preprocess. - int64 current_flop_count_; - int64 current_transcendental_count_; - int64 current_bytes_accessed_; - - // Mapping from HLO instructions to the cost we computed for them in the - // course of the graph analysis. - std::map hlo_to_flop_count_; - std::map hlo_to_transcendental_count_; - std::map hlo_to_bytes_accessed_; + HloToProperties hlo_properties_; + + // If true, the time taken will be computed from the rates for each property + // and the total time will be the maximum time, which is the time of the + // bottleneck. + bool current_should_compute_bottleneck_time_; + + // The properties of the currently visited instruction. A HandleFoo method can + // modify these to change the default values computed in Preprocess. + Properties current_properties_; + + // The sum of the properties of all HLOs in the computation. + Properties properties_sum_; + + // How much of each property can be processed per second. E.g. if the property + // is bytes accessed, this is the number of bytes that can be processed per + // second. Is empty if no rates have been set. 
+ Properties per_second_rates_; TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis); }; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index b74c7eb4e074bd8f340137066b6d9675bb32cee1..f74568316518dd6951672923411ef023cec3b50b 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -332,48 +332,64 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { using FusionCostAnalysis = ::testing::Test; TEST_F(FusionCostAnalysis, LoopFusion) { - Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); - - // Fuse all instructions in complicated expression: - // - // add = Add(C1, C2) - // clamp = Clamp(C2, add, add) - // exp = Exp(add) - // mul = Mul(exp, C3) - // sub = Sub(mul, clamp) - // tuple = Tuple({sub, sub, mul, C1}) - auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( - /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)); - auto c2 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( - /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)); - auto c3 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( - /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)); - - auto add = - HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(), c2.get()); - auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, c2.get(), - add.get(), add.get()); - auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get()); - auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, - exp.get(), c3.get()); - auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, - mul.get(), clamp.get()); - auto tuple = - HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()}); - - auto fusion = HloInstruction::CreateFusion( - r2f32, HloInstruction::FusionKind::kLoop, tuple.get()); - fusion->FuseInstruction(sub.get()); - 
fusion->FuseInstruction(mul.get()); - fusion->FuseInstruction(exp.get()); - fusion->FuseInstruction(clamp.get()); - fusion->FuseInstruction(add.get()); - - HloCostAnalysis fusion_analysis(ShapeSize); - ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); - - EXPECT_EQ(fusion_analysis.flop_count(), 16); - EXPECT_EQ(fusion_analysis.transcendental_count(), 4); + // Do this 4 times with different per-second rates to test the computation of + // bottleneck time on fusion nodes. + for (int i = 0; i < 4; ++i) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + + // Fuse all instructions in complicated expression: + // + // add = Add(C1, C2) + // clamp = Clamp(C2, add, add) + // exp = Exp(add) + // mul = Mul(exp, C3) + // sub = Sub(mul, clamp) + // tuple = Tuple({sub, sub, mul, C1}) + auto c1 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)); + auto c2 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)); + auto c3 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)); + + auto add = HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1.get(), + c2.get()); + auto clamp = HloInstruction::CreateTernary(r2f32, HloOpcode::kClamp, + c2.get(), add.get(), add.get()); + auto exp = HloInstruction::CreateUnary(r2f32, HloOpcode::kExp, add.get()); + auto mul = HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, + exp.get(), c3.get()); + auto sub = HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, + mul.get(), clamp.get()); + auto tuple = HloInstruction::CreateTuple( + {sub.get(), sub.get(), mul.get(), c1.get()}); + + auto fusion = HloInstruction::CreateFusion( + r2f32, HloInstruction::FusionKind::kLoop, tuple.get()); + fusion->FuseInstruction(sub.get()); + fusion->FuseInstruction(mul.get()); + fusion->FuseInstruction(exp.get()); + fusion->FuseInstruction(clamp.get()); + 
fusion->FuseInstruction(add.get()); + + // The time given these rates at i == 0 is exactly even among the properties + // at 1.0 seconds. For other values, one of the rates is slower so that it + // becomes the bottleneck. + HloCostAnalysis fusion_analysis(ShapeSize); + fusion_analysis.set_flops_per_second(16 * (i == 1 ? 1 / 2.0 : 1.0)); + fusion_analysis.set_transcendentals_per_second(4 * + (i == 2 ? 1 / 4.0 : 1.0)); + fusion_analysis.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0)); + ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); + + EXPECT_EQ(fusion_analysis.flop_count(), 16); + EXPECT_EQ(fusion_analysis.transcendental_count(), 4); + constexpr int64 bytes_accessed = sizeof(float) * 4 * 2 * 2; + static_assert(bytes_accessed == 64, ""); + EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed); + + EXPECT_EQ(fusion_analysis.seconds(), 1 << i); + } } TEST_F(FusionCostAnalysis, NoLayout) { @@ -383,9 +399,8 @@ TEST_F(FusionCostAnalysis, NoLayout) { shape_without_layout.clear_layout(); auto c1 = HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D(2, 3, 4, 5))); - auto c2 = - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); + Literal::CreateR4FromArray4D(Array4D(2, 3, 4, 5))); + auto c2 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); auto broadcast = HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1}); diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 4c6af5c40fa563d1c656eb152819e454aae5fb69..0fef89a06d01779114b9ac4e0a25a9ae9ded1aef 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -68,7 +68,7 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (LiteralUtil::Equal(instruction->literal(), 
it->second->literal())) { + if (instruction->literal().Equal(it->second->literal())) { match = it->second; break; } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index cc39c3ac20396f9648b5d325933aad819275b2a6..8b0b9c8bbd0cf442149b32a4539277b2daeed90e 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -51,9 +51,9 @@ TEST_F(HloCseTest, CombineTwoConstants) { // Test that two identical constants are commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -67,10 +67,10 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = computation->instructions().begin()->get(); - EXPECT_EQ(42.0f, LiteralUtil::Get(constant->literal(), {})); + EXPECT_EQ(42.0f, constant->literal().Get({})); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR0(84.0); + auto expected = Literal::CreateR0(84.0); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -102,7 +102,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -132,7 +132,7 @@ TEST_F(HloCseTest, 
CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -141,20 +141,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // commoned. auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); // Duplicate the float constant to verify something happens. builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -171,13 +171,13 @@ TEST_F(HloCseTest, NonscalarConstants) { // Test that identical nonscalar constants are merged. 
auto builder = HloComputation::Builder(TestName()); auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); // Create a constant which has the same shape but a different value. auto uncommon_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); + Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); // Tie the constants together with a tuple. This makes it easier to refer to // the constant instructions via their use. @@ -206,7 +206,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test that three identical instructions are commoned. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -236,7 +236,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { // commoned if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -267,7 +267,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { // the pass is layout insensitive. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -311,7 +311,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { // The *1 instructions should be merged with the *2 instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNegate, constant)); @@ -349,9 +349,9 @@ TEST_F(HloCseTest, DoNotCombineRng) { // Test that two RNG ops are not commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); @@ -392,9 +392,9 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName() + "_rng_fun"); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto rng = 
builder.AddInstruction(HloInstruction::CreateRng( scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); auto param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -409,7 +409,7 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({5.0f}))); + HloInstruction::CreateConstant(Literal::CreateR1({5.0f}))); auto rng1 = builder.AddInstruction( HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); auto rng2 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d1b87256445e4fd51134a66666e5736baf272c71..91592c19024271c72e2462f0ced58f79e098f4c3 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -16,14 +16,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include -#include #include -#include #include #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -35,7 +32,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -43,209 +39,6 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -string HloLocation::ToString() const { - string index_str = - ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; - return StrCat(instruction->FullyQualifiedName(), index_str); -} - -std::ostream& operator<<(std::ostream& out, const HloLocation& location) { - out << location.ToString(); - return out; -} - -string HloUse::ToString() const { - string index_str = - ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) - ? (" " + operand_index.ToString()) - : ""; - return StrCat(instruction->FullyQualifiedName(), ", operand ", operand_number, - index_str); -} - -std::ostream& operator<<(std::ostream& out, const HloUse& use) { - out << use.ToString(); - return out; -} - -HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, - const ShapeIndex& index, bool is_phi) - : id_(id), is_phi_(is_phi) { - // The defining location is always the first element in the locations_ vector. - AddLocation(instruction, index); -} - -bool HloValue::operator==(const HloValue& other) const { - bool equal = instruction() == other.instruction() && index() == other.index(); - // If the values are equal they most both be phi (or non phi). - CHECK(!(equal && is_phi() != other.is_phi())); - return equal; -} - -bool HloValue::operator!=(const HloValue& other) const { - return !(*this == other); -} - -string HloValue::ToShortString() const { - string index_str = - ShapeUtil::IsTuple(instruction()->shape()) ? index().ToString() : ""; - return StrCat(is_phi_ ? 
"PHI " : "", instruction()->FullyQualifiedName(), - index_str); -} - -string HloValue::ToString(int indent) const { - string indentation(indent, ' '); - string out = StrCat(indentation, ToShortString(), ", locations:\n"); - for (const HloLocation& location : locations()) { - StrAppend(&out, indentation, " ", location.ToString(), "\n"); - } - StrAppend(&out, indentation, " uses:\n"); - for (const HloUse& use : uses()) { - StrAppend(&out, indentation, " ", use.ToString(), "\n"); - } - return out; -} - -void HloValue::AddLocation(HloInstruction* instruction, - const ShapeIndex& index) { - // The given location should not already exist in locations_. - for (const HloLocation& location : locations_) { - DCHECK(!(location.instruction == instruction && location.index == index)); - } - - locations_.push_back(HloLocation{instruction, index}); - - // Update uses. - for (HloInstruction* user : instruction->users()) { - for (int64 operand_number : user->OperandIndices(instruction)) { - if (!DoesNotUseOperandBuffer(instruction, index, user)) { - for (const HloUse& use : uses_) { - // Verify that this use does not already exist. - DCHECK(!(use.instruction == user && - use.operand_number == operand_number && - use.operand_index == index)); - } - - uses_.push_back(HloUse{user, operand_number, index}); - } - } - } - - // Update liveout status of this HloValue. - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - } -} - -void HloValue::RemoveLocation(HloInstruction* instruction, - const ShapeIndex& index) { - // The defining location cannot be removed. 
- CHECK(!(instruction == this->instruction() && index == this->index())); - - int64 size_before = locations_.size(); - locations_.erase( - std::remove_if(locations_.begin(), locations_.end(), - [instruction, &index](const HloLocation& location) { - return location.instruction == instruction && - location.index == index; - }), - locations_.end()); - // Only a single location should have been removed. - CHECK_EQ(locations_.size(), size_before - 1); - - // Update uses which referred to this location. - uses_.erase(std::remove_if(uses_.begin(), uses_.end(), - [instruction, &index](const HloUse& use) { - return use.instruction->operand( - use.operand_number) == instruction && - use.operand_index == index; - }), - uses_.end()); - - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - // Value has been removed from a location in the entry root instruction. - // Check if the value is still live out of the module by walking all - // remaining locations. 
- live_out_of_module_ = false; - for (const HloLocation& location : locations()) { - if (location.instruction == - module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - break; - } - } - } -} - -std::ostream& operator<<(std::ostream& out, const HloValue& value) { - out << value.ToShortString(); - return out; -} - -void HloValueSet::SortAndUniquifyValues() { - std::sort(value_ids_.begin(), value_ids_.end()); - value_ids_.erase(std::unique(value_ids_.begin(), value_ids_.end()), - value_ids_.end()); -} - -string HloValueSet::ToString() const { - return StrCat("HloValueSet: ", tensorflow::str_util::Join(value_ids_, ", ")); -} - -/*static */ -HloValueSet HloValueSet::Union( - tensorflow::gtl::ArraySlice inputs) { - HloValueSet union_set; - for (const HloValueSet* input : inputs) { - for (HloValue::Id value_id : input->value_ids()) { - union_set.value_ids_.push_back(value_id); - } - } - union_set.SortAndUniquifyValues(); - return union_set; -} - -std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { - out << value_set.ToString(); - return out; -} - -InstructionValueSet InstructionValueSet::Union( - tensorflow::gtl::ArraySlice inputs) { - CHECK_GT(inputs.size(), 0); - for (int i = 1; i < inputs.size(); ++i) { - CHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); - } - InstructionValueSet union_set(inputs[0]->shape()); - union_set.ForEachMutableElement( - [&inputs](const ShapeIndex& index, HloValueSet* value_set) { - std::vector input_sets; - for (const InstructionValueSet* input : inputs) { - input_sets.push_back(&input->element(index)); - } - *value_set = HloValueSet::Union(input_sets); - }); - return union_set; -} - -std::ostream& operator<<(std::ostream& out, - const InstructionValueSet& instruction_value_set) { - out << instruction_value_set.ToString(); - return out; -} - -string InstructionValueSet::ToString() const { - string out = - StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), 
")\n"); - ForEachElement([this, &out](const ShapeIndex& index, - const HloValueSet& value_set) { - StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); - }); - return out; -} - HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form, bool bitcast_defines_value) : module_(module), @@ -259,7 +52,8 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, if (value_set.value_ids().size() != 1) { return false; } - return GetValue(value_set.GetUniqueValueId()).instruction() == instruction; + return GetValue(value_set.GetUniqueValueId()).defining_instruction() == + instruction; } const HloValue& HloDataflowAnalysis::GetValueDefinedAt( @@ -305,7 +99,7 @@ string HloDataflowAnalysis::ToString() const { module_->computations()) { for (const std::unique_ptr& instruction : computation->instructions()) { - StrAppend(&out, " ", instruction->FullyQualifiedName(), ":\n"); + StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { GetInstructionValueSet(instruction.get()) .ForEachElement([this, &instruction, &out]( @@ -468,8 +262,8 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt( } // Don't remove the defining location of the value. 
HloValue& value = GetValue(value_id); - if (instruction == value.instruction()) { - CHECK_EQ(index, value.index()); + if (instruction == value.defining_instruction()) { + CHECK_EQ(index, value.defining_index()); } else { value.RemoveLocation(instruction, index); } @@ -482,8 +276,8 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt( const HloValueSet& value_set) { for (HloValue::Id value_id : value_set.value_ids()) { HloValue& value = GetValue(value_id); - if (instruction == value.instruction()) { - CHECK_EQ(index, value.index()); + if (instruction == value.defining_instruction()) { + CHECK_EQ(index, value.defining_index()); } else { value.AddLocation(instruction, index); } @@ -694,15 +488,24 @@ InstructionValueSet HloDataflowAnalysis::RecomputeParameterValueSet( std::vector inputs; bool called_from_while = false; for (const CallSite& callsite : call_graph_node.caller_callsites()) { - inputs.push_back(&GetInstructionValueSet( - callsite.instruction()->operand(parameter->parameter_number()))); - if (callsite.instruction()->opcode() == HloOpcode::kWhile) { - // In a while instruction, the backedge is also a dataflow input to the - // parameter instruction. This code covers the case where the parameter is - // in the while body or the parameter is in the while condition. + if (callsite.instruction()->opcode() == HloOpcode::kCall) { + // The operand values of a call instruction are forwarded to the + // respective parameter instruction of the subcomputation. + inputs.push_back(&GetInstructionValueSet( + callsite.instruction()->operand(parameter->parameter_number()))); + } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + // In a while instruction, the while operand (ie, the init value) and the + // backedge are dataflow inputs to the parameter instruction. This is the + // case for parameters of both the body and condition computations. 
+ CHECK_EQ(parameter->parameter_number(), 0); + inputs.push_back( + &GetInstructionValueSet(callsite.instruction()->operand(0))); inputs.push_back(&GetInstructionValueSet( callsite.instruction()->while_body()->root_instruction())); called_from_while = true; + } else { + LOG(FATAL) << "CallContext::kSequential computations should only be " + "called from call or while instructions"; } } @@ -804,6 +607,149 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { return Status::OK(); } +bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const { + // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' + // is live into the module. + if (b.defining_instruction()->parent() == module_->entry_computation() && + b.defining_instruction()->opcode() == HloOpcode::kParameter) { + return false; + } + + // Phi values require special handling. Because XLA does not have a phi + // instruction, the definition instruction of the phis values are + // placeholders: either the subcomputation parameter (body or condition) or + // the while instruction. However, the program point where these values are + // logically defined does not necessarily coincide exactly with program point + // of these place-holder instructions. So we explicitly define the following + // order for phi values: + // + // body/condition parameter phi: + // Defined before all values defined in its computation excepting other + // phis. + // + // while phi: + // defined after all values defined in the condition or body. 
+ // + auto is_body_or_condition_phi = [](const HloValue& v) { + return v.is_phi() && + v.defining_instruction()->opcode() == HloOpcode::kParameter; + }; + if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) && + call_graph_->InstructionIsNestedIn(b.defining_instruction(), + a.defining_instruction()->parent())) { + return true; + } + if (is_body_or_condition_phi(b) && + call_graph_->InstructionIsNestedIn(a.defining_instruction(), + b.defining_instruction()->parent())) { + return false; + } + + // If 'b' is a while phi and 'a' is in the body or condition, then 'a' + // executes before 'b'. + if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile && + (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), b.defining_instruction()->while_body()) || + call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->while_condition()))) { + return true; + } + + return ordering.ExecutesBefore(a.defining_instruction(), + b.defining_instruction()); +} + +bool HloDataflowAnalysis::UseIsBeforeValueDefinition( + const HloUse& use, const HloValue& value, + const HloOrdering& ordering) const { + if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) { + return true; + } + + // If the use is at the instruction where the value is defined, then the use + // is before the def if the instruction allows buffer sharing (in place + // computation). + if (use.instruction == value.defining_instruction() && + CanShareOperandBufferWithUser( + use.instruction->mutable_operand(use.operand_number), + use.operand_index, value.defining_instruction(), + value.defining_index())) { + return true; + } + + // The use at a while is an input to a phi, and logically occurs before values + // are defined in the body or condition computations. 
+ if (use.instruction->opcode() == HloOpcode::kWhile) { + const HloInstruction* xla_while = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_body()) || + call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_condition())) { + return true; + } + } + + // Similarly if the value is defined at a while, it logically occurs after any + // uses in the body or condition computations. + if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { + CHECK(ssa_form_); + const HloInstruction* xla_while = value.defining_instruction(); + if (call_graph_->InstructionIsNestedIn(use.instruction, + xla_while->while_body()) || + call_graph_->InstructionIsNestedIn(use.instruction, + xla_while->while_condition())) { + return true; + } + } + return false; +} + +bool HloDataflowAnalysis::LiveRangeStrictlyBefore( + const HloValue& a, const HloValue& b, const HloOrdering& ordering) const { + VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() + << ", b = " << b.ToShortString() << ")"; + if (!IsDefinedBefore(a, b, ordering)) { + VLOG(4) << "a not defined before b"; + return false; + } + + // Live-out values from the module can never have ranges strictly before any + // other value. + if (a.live_out_of_module()) { + VLOG(4) << "a is live out of module"; + return false; + } + + // Live-out values of computations can never have ranges strictly before any + // other value in the computation (including values nested in + // subcomputations). + if (a.live_out_of_computation() && + call_graph_->InstructionIsNestedIn(b.defining_instruction(), + a.defining_instruction()->parent())) { + VLOG(4) << "a is live out of computation containing b"; + return false; + } + + // All uses of 'a' must be before 'b' is defined. 
+ for (const HloUse& use : a.uses()) { + if (!UseIsBeforeValueDefinition(use, b, ordering)) { + VLOG(4) << "use of a (" << use << ") not before b is defined"; + return false; + } + } + + return true; +} + +bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const { + // Buffers without disjoint liveness may interfere. + return !LiveRangeStrictlyBefore(a, b, ordering) && + !LiveRangeStrictlyBefore(b, a, ordering); +} + /* static */ StatusOr> HloDataflowAnalysis::Run( HloModule* module, bool ssa_form, bool bitcast_defines_value) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 2f9b0a64be5a00f490e5fc678ac5589e374f80d7..d909c5b668e661bae282b38bc0ce02845bf730f2 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -20,7 +20,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ -#include +#include #include #include #include @@ -28,222 +28,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" namespace xla { -// Abstraction which identifies a specific point in the XLA graph. An -// HloLocation specifies a ShapeIndex within the output of a specific -// instruction. -struct HloLocation { - HloInstruction* instruction; - ShapeIndex index; - - string ToString() const; - - bool operator==(const HloLocation& other) const { - return instruction == other.instruction && index == other.index; - } - bool operator!=(const HloLocation& other) const { return !(*this == other); } -}; - -std::ostream& operator<<(std::ostream& out, const HloLocation& location); - -// Defines a single use of an HLO value. -struct HloUse { - // Instruction at which the value is used. - HloInstruction* instruction; - - // The operand number in which the value is appears. - int64 operand_number; - - // The shape index within the operand in which the value appears. 
- ShapeIndex operand_index; - - string ToString() const; - - bool operator==(const HloUse& other) const { - return instruction == other.instruction && - operand_number == other.operand_number && - operand_index == other.operand_index; - } - - bool operator!=(const HloUse& other) const { return !(*this == other); } -}; - -std::ostream& operator<<(std::ostream& out, const HloUse& use); - -// Class describing a value used by the dataflow analysis. XLA arrays are -// trivially a single HloValue. Tuples are made up of more than one HloValue: an -// HloValue for the pointer vector, and an HloValue for each child element. -// -// Every HloValue is defined by a particular instruction and most instructions -// define only a single HloValue. Instructions which define a single HloValue -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single HloValue -// which is a vector of pointers to the values containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple values only the top-level HloValue (the vector of pointers) is -// defined by the Tuple instruction. The values containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one HloValue. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. These tuple-shaped instructions do -// not assemble a tuple from existing HloValues like the Tuple instruction does, -// but rather define all the HloValues in the tuple. -class HloValue { - public: - using Id = int64; - - // Construct an HloValue defined by 'instruction' at shape index 'index'. If - // is_phi is true, then this value is a phi value, for example, at the - // parameter of a while body computation. 
Phi values are only used in the SSA - // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). - HloValue(HloValue::Id id, HloInstruction* instruction, - const ShapeIndex& index, bool is_phi = false); - - // Return a unique identifier for this HloValue. This value is used for stable - // sorting and iteration - Id id() const { return id_; } - - // Returns whether this value is a phi value. - bool is_phi() const { return is_phi_; } - - // Return the location where this value is defined. - const HloLocation& DefinitionLocation() const { return locations_[0]; } - - // Return the instruction which defines this HloValue. - HloInstruction* instruction() const { - return DefinitionLocation().instruction; - } - - // Return the shape index at which this HloValue is defined in the output of - // instruction(). - const ShapeIndex& index() const { return DefinitionLocation().index; } - - // Add or remove a location at which the HloValue appears. The definition - // location can not be removed. The uses of the HloValue are updated. - void AddLocation(HloInstruction* instruction, const ShapeIndex& index); - void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index); - - // Return all locations of the HloValue in the module. - const std::vector& locations() const { return locations_; } - - // Return all uses of the HloValue. - const std::vector& uses() const { return uses_; } - - // Set/get whether this HloValue is live out of the module. - bool live_out_of_module() const { return live_out_of_module_; } - - bool operator==(const HloValue& other) const; - bool operator!=(const HloValue& other) const; - - // Return a single-line string representation of the value. - string ToShortString() const; - - string ToString(int indent = 0) const; - - private: - // Unique identifier for this HloValue. Used for stable sorting and iteration. - const Id id_; - - // Whether this instruction is a phi value. - const bool is_phi_; - - // The set of locations of this HloValue. 
The first element is always the - // location of the definition. - std::vector locations_; - - // The set of uses of this HloValue. - std::vector uses_; - - // Whether this value is live out of the HLO module. - bool live_out_of_module_ = false; -}; - -std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); - -// A class representing the possible set of HloValues at a particular point -// (shape index in the output of an instruction) in the XLA graph. This set -// contains the set of reaching HloValue definitions. For a simple array-shaped -// instruction like Add, the HloValueSet of the top-level of the instruction's -// output trivially contains only the HloValue defined by the instruction. For -// instructions which have non-trivial dataflow such as Tuple or Select, the -// HloValueSets of the instruction's output contains one or more HloValues -// defined by the instruction's operands or defined further up in the XLA graph. -class HloValueSet { - public: - HloValueSet() = default; - - explicit HloValueSet(tensorflow::gtl::ArraySlice value_ids) - : value_ids_(value_ids.begin(), value_ids.end()) { - SortAndUniquifyValues(); - } - - // Return the union of the given HloValueSets. - static HloValueSet Union( - tensorflow::gtl::ArraySlice inputs); - - // Return the vector of the IDs of all HloValues in the set. Values in the - // vector are unique and sorted. - const std::vector& value_ids() const { return value_ids_; } - - // Return the unique HLO value in the set. CHECKs if the set does not contain - // exactly one value. - HloValue::Id GetUniqueValueId() const { - CHECK_EQ(value_ids().size(), 1); - return value_ids()[0]; - } - - bool operator==(const HloValueSet& other) const { - return value_ids() == other.value_ids(); - } - bool operator!=(const HloValueSet& other) const { return !(*this == other); } - - string ToString() const; - - private: - // Sorts value_ and removes duplicates. This should be called after adding any - // elements to values_. 
- void SortAndUniquifyValues(); - - // HloValues sorted by HloValue::Id. - std::vector value_ids_; -}; - -std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); - -// A class collecting the HloValues which might be contained in the output of -// an HLO instruction. For array-shaped instructions, an InstructionValueSet -// trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets -// hold multiple HloValueSets. -class InstructionValueSet : public ShapeTree { - public: - InstructionValueSet(const Shape& shape) : ShapeTree(shape) {} - - // Return the union of the given InstructionValueSets. - static InstructionValueSet Union( - tensorflow::gtl::ArraySlice inputs); - - string ToString() const; -}; - -std::ostream& operator<<(std::ostream& out, - const InstructionValueSet& instruction_value_set); - // Analysis which identifies all HLO values and their uses in an HLO module. class HloDataflowAnalysis { public: @@ -309,6 +105,17 @@ class HloDataflowAnalysis { const HloValue& GetValue(HloValue::Id value_id) const; HloValue& GetValue(HloValue::Id value_id); + // Returns whether the given values interfere assuming the given HLO + // ordering. Two values interfere if they may both be simultaneously live. + bool MayInterfere(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Overload which takes HloValue:Ids. + bool MayInterfere(HloValue::Id a, HloValue::Id b, + const HloOrdering& ordering) const { + return MayInterfere(GetValue(a), GetValue(b), ordering); + } + // Return the total number of HloValues. int64 value_count() const { return values_.size(); } @@ -374,6 +181,20 @@ class HloDataflowAnalysis { HloInstruction* instruction, const InstructionValueSet& new_value_set, const InstructionValueSet* prev_value_set = nullptr); + // Returns true if the live range of the given value 'a' is strictly before + // the live range of value 'b' using the given HLO ordering. 
+ bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Returns whether the value 'a' is defined before the value 'b' under the + // given ordering. + bool IsDefinedBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Returns whether the given use is before the given value definition. + bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value, + const HloOrdering& ordering) const; + HloModule* const module_; const bool ssa_form_; const bool bitcast_defines_value_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 21344af5f224843a857984162a36b8a09915e607..79edd0fcb59023d4c89cbe4712c3bc9446834ef5 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -39,14 +39,14 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(TestName()) {} + HloDataflowAnalysisTest() : module_(CreateNewModule()) {} // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. const HloDataflowAnalysis& RunAnalysis(bool ssa_form, bool bitcast_defines_value = false) { analysis_ = - HloDataflowAnalysis::Run(&module_, ssa_form, bitcast_defines_value) + HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); return *analysis_; } @@ -63,22 +63,34 @@ class HloDataflowAnalysisTest : public HloTestBase, return values; } - HloModule module_; + // Returns true if the top-level values for instructions 'a' and 'b' may + // interfere. Precondition: 'a' and 'b' define array-shaped values. 
+ bool InstructionsMayInterfere(const HloOrdering& ordering, + const HloInstruction* a, + const HloInstruction* b) { + EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); + EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a), + analysis_->GetValueDefinedAt(b), ordering); + } + + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); + const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { // Test the dataflow for a simple binary operation (Add). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -126,7 +138,7 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -158,27 +170,21 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { // Verify uses. Of interest is that a GetTupleElement instruction is only a // use of the top-level value in the tuple operand. 
EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(), - UnorderedElementsAre(HloUse{tuple, 0, {}}, HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(), - UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } TEST_P(HloDataflowAnalysisTest, NestedTuple) { - // Verify the dataflow through a nested tuple of the following form for two - // constants %constant1 and %constant2: - // - // %nested_tuple = {{%constant1, %constant2}, - // {%constant1, %constant2}, - // %constant1} - // + // Verify the dataflow through a nested tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto nested_tuple = builder.AddInstruction( @@ -187,7 +193,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloInstruction::CreateGetTupleElement(tuple->shape(), nested_tuple, 1)); auto gte_out = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -202,18 +208,15 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}}, HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, 
{0}}, HloLocation{gte_out, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre( - HloUse{tuple, 0, {}}, HloUse{nested_tuple, 0, {0}}, - HloUse{nested_tuple, 1, {0}}, HloUse{nested_tuple, 2, {}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{nested_tuple, 0, {1}}, - HloUse{nested_tuple, 1, {1}})); + // Constant values should have no uses though one is live out. The locations + // where they appear as operands are on instructions which do not use the + // values (eg, Tuple). + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + + // The top-level tuple values are used in GTE instructions. EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), - UnorderedElementsAre(HloUse{nested_tuple, 0, {}}, - HloUse{nested_tuple, 1, {}}, - HloUse{gte_out, 0, {}})); + UnorderedElementsAre(HloUse{gte_out, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}).uses(), UnorderedElementsAre(HloUse{gte_tuple, 0, {}})); @@ -236,16 +239,16 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -268,11 +271,12 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { @@ -285,20 +289,20 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kSubtract, call1, call2)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -316,17 +320,18 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call1, 0, {}}, - HloUse{call2, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call1, 1, {}}, - HloUse{call2, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); // The Add from the subcomputation is used as both operands of the Subtract. EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { @@ -339,18 +344,18 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + 
HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -392,7 +397,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1)); HloComputation* inner_computation = - module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); auto outer_builder = HloComputation::Builder("OuterComputation"); auto outer_param0 = outer_builder.AddInstruction( @@ -400,19 +405,19 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto outer_param1 = outer_builder.AddInstruction( HloInstruction::CreateParameter(1, scalar_shape_, "param1")); // Swizzle parameters. 
- auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( + outer_builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {outer_param1, outer_param0}, inner_computation)); HloComputation* outer_computation = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); - auto call = builder.AddInstruction(HloInstruction::CreateCall( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -423,14 +428,10 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { // Verify that the uses of the constants are properly swizzled by parameter // permutation in nested_call. 
- EXPECT_THAT( - analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, - HloUse{add, 1, {}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, - HloUse{add, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } @@ -465,33 +466,37 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); - auto body_tuple = body_builder.AddInstruction( + body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); // Condition computation trivially returns a constant "false". 
auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + EXPECT_TRUE( + analysis.GetValueDefinedAt(cond_constant).live_out_of_computation()); + EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); + if (ssa_form) { // Element 0 of the tuple passed through the body so no phi value is // defined. 
@@ -507,15 +512,17 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi()); - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{tuple, 0, {}}, - HloUse{xla_while, 0, {0}}, - HloUse{body_tuple, 0, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}})); // Constant1 passes through the body and out of the module. EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) .live_out_of_module()); + + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); } else { // While instruction and subcomputation parameters should not define values // in non-ssa form. 
@@ -528,6 +535,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } } @@ -565,21 +573,21 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -588,7 +596,7 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); auto xla_while2 = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = 
GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -630,9 +638,9 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); // Element 0 passes transparently through the body. auto inner_builder = HloComputation::Builder("inner_body"); @@ -647,7 +655,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { inner_builder.AddInstruction( HloInstruction::CreateTuple({inner_element_0, add})); HloComputation* inner_body = - module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); // Element 1 passes transparently through the body. auto outer_builder = HloComputation::Builder("outer_body"); @@ -664,18 +672,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile( tuple_shape, condition, inner_body, outer_tuple)); HloComputation* outer_body = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -751,26 +759,26 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_1, body_element_0})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -817,15 +825,15 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { // Test a kSelect of an array value. 
auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -841,15 +849,15 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { // instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -868,7 +876,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { auto select1234 = 
builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, select12, select34)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -899,31 +907,33 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { analysis.GetValueDefinedAt(constant4))); EXPECT_THAT( - analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{tuple1, 0, {}}, HloUse{select11, 1, {0}}, - HloUse{select11, 2, {0}}, HloUse{select12, 1, {0}}, - HloUse{select1234, 1, {0}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{tuple2, 0, {}}, HloUse{select12, 2, {0}}, - HloUse{select1234, 1, {0}})); + analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}}, + HloUse{select12, 1, {}})); + + // The two constant values just pass through the Selects and are not + // used. They are live out however. + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); } TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { // Test kSelect of a nested tuple. 
auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto constant5 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant3})); auto tuple1 = builder.AddInstruction( @@ -935,7 +945,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -993,24 +1003,24 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); 
cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -1024,7 +1034,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1062,11 +1072,11 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { // Test the bitcast_defines_value flag to the dataflow analysis. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); { @@ -1102,7 +1112,7 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1126,6 +1136,352 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); } +TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { + // A simple chain of elementwise operations. No values should interfere. + // + // param --> negate -> exp -> log + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp)); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + // No values should interfere. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, log, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, log, exp)); + + // A value should interfere with itself. + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, exp)); +} + +TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { + // Two entry params, which interfere with each other. + // + // param0 --> negate ---------------\ + // param1 --> exp --> add + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, vector_shape_, "param1")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + auto entry = module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param0, negate, param1, exp, add}}); + SequentialHloOrdering ordering(module_.get(), sequence); + + // Entry parameters interfere as if they are defined simultaneously at + // the very beginning. 
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param0, param1)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, add)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, param0)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, add)); + + // Negate and exp still interfere. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + + // But {negate, add} and {exp, add} don't interfere. + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { + // Similar to MultipleEntryParameters_Sequential, but the parameter is of + // while body computation. Body computation in the sequential order: + // + // %constant = Constant(...) + // %exp = Exp(%constant) + // %param = Param(0) + // %add = Add(%param, %exp) ;; Root of body + // %dead_constant = Constant(...) + // %dead_negate = Negate(%dead_constant) + // + // %constant and its only use %exp are ordered before 'param'. However, the + // %constant and %param values still interfere because the parameter is + // considered live into the while body. + // + // Similarly, %dead_constant and %dead_negate are ordered after the root of + // the body computation %add. However, %add is liveout of the computation so + // %dead_constant and %add interfere. 
+ auto body_builder = HloComputation::Builder(TestName()); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "body_param")); + auto constant = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto exp = body_builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, exp, body_param)); + auto dead_constant = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, dead_constant)); + HloComputation* body = module_->AddEmbeddedComputation( + body_builder.Build(/*root_instruction=*/add)); + + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "cond_param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param)); + + auto entry = module_->AddEntryComputation(builder.Build()); + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param, xla_while}}); + sequence.insert({condition, {cond_param, cond_constant}}); + // Construct the order such that 'constant' and its use 'exp' are before + // body_param. 
+ sequence.insert({body, {constant, exp, body_param, add}}); + + SequentialHloOrdering ordering(module_.get(), sequence); + + // 'add' is the body root even though later instructions follow in the order + // like 'dead_negate'. Only 'add' should be live out of the computation. + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_FALSE( + analysis.GetValueDefinedAt(dead_negate).live_out_of_computation()); + + // 'add' is live out of the body and will interfere with later instructions + // such as 'dead_constant' and 'dead_negate'. + EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate)); + + // The remaining checks test phi values defined by body and condition + // parameters which only occur in the SSA form of the analysis. + if (ssa_form) { + // Though the ordering suggests 'constant' and 'param' should not interfere, + // 'param' is live in and thus interferes with any earlier instruction of + // the computation in the order (e.g. 'constant'). + EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add)); + + // The following values end up in the same buffer: + // (1) the init value: 'param' + // (2) the body parameter: 'body_param' + // (3) the condition parameter: 'cond_param' + // (4) the root value of the while body: 'add' + // (5) the while value: 'xla_while' + // None should interfere. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while)); + } +} + +TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) { + // A chain of operations with two elementwise and one non-elementwise. The + // elementwise op should not interfere with its operand, while the + // non-elementwise op should interfere. Entry params always interfere. + // + // param --> exp -> negate -> reverse + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp)); + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(vector_shape_, negate, {0})); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, reverse)); + + // Negate is elementwise, so doesn't interfere with its operand. + // Reverse is non-elementwise, so does interfere with its operand. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, reverse)); +} + +TEST_P(HloDataflowAnalysisTest, OverlappedValues) { + // Verify simultaneously live values interfere (exp and negate). + // + // param --> negate -> add + // \---> exp -----/ + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + + // Negate and exp interfere with each other, but not with add. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { + // Identical to the test OverlappedValues but using a sequential ordering of + // HLO instructions. + // + // param --> negate -> add + // \---> exp -----/ + // + // Sequential order: + // param, negate, exp, add + // + // Liveness is identical to the DependencyHloOrdering. 
+ auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + auto entry = module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + SequentialHloOrdering::HloModuleSequence sequence; + std::vector order = {param, negate, exp, add}; + sequence.emplace(entry, order); + + SequentialHloOrdering ordering(module_.get(), sequence); + + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + + // Negate and exp interfere with each other, but not with add. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { + // Test MayInterfere() for embedded computation, specifically the interference + // of values in different computations. 
+ // + // embedded_computation: + // %embedded_param = Param(0) + // %embedded_log = Log(%embedded_param) + // + // entry computation: + // %param = Param(0) + // %negate = Negate(%param) + // %exp = Exp(%param) + // %call = Call(embedded_computation, {%exp}) + // %add = Add(%negate, %call) + // + // Note %negate is live across the call and should interfere with all values + // in the embedded computation. + auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); + auto embedded_param = embedded_builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "embedded_param")); + auto embedded_log = + embedded_builder.AddInstruction(HloInstruction::CreateUnary( + vector_shape_, HloOpcode::kLog, embedded_param)); + auto embedded_computation = + module_->AddEmbeddedComputation(embedded_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation)); + builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, call)); + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + // Exp's only use is the call so it should not interfere with values inside the + // embedded computation. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log)); + + // Negate is live across the call and should interfere with values in the + // embedded computation + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 10cd7ca7c0990ab553c865da01b00475382316e2..704b8dfca700f7c4a00689593aea9743de1f817c 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -45,9 +45,9 @@ TEST_F(HloDceTest, NoDeadCode) { // Verify that no dead code is removed from a computation with no dead code. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -98,9 +98,9 @@ TEST_F(HloDceTest, ControlDependencies) { // Verify that instructions with control dependencies are not removed. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); // Create two dead instructions: a negate and an add. 
auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3e7f5b1f3d97ace48fbc22b224667acebcc52093..a0c5cbe916050a8aa7849c3e37daad70bc8d6190 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -31,11 +31,13 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -89,11 +91,11 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = LiteralUtil::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index)); + auto result = Literal::CreateFromShape(shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); @@ -117,12 +119,11 @@ StatusOr> ElementWiseUnaryOpImpl( ShapeUtil::HumanString(operand->shape()).c_str()); } - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), 
[&](tensorflow::gtl::ArraySlice multi_index) { - return unary_op( - LiteralUtil::Get(operand_literal, multi_index)); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); } @@ -168,6 +169,23 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleAbs(abs, operand); }; + Status HandleBroadcast(HloInstruction* broadcast) override { + parent_->evaluated_[broadcast] = + Literal::CreateFromShape(broadcast->shape()); + auto output = parent_->evaluated_[broadcast].get(); + auto operand_to_broadcast = + parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); + std::vector broadcast_indices( + ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); + return output->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; + } + return operand_to_broadcast.Get(broadcast_indices); + }); + } + Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { @@ -176,7 +194,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override { + Status HandleCopy(HloInstruction* copy) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy], ElementWiseUnaryOp(copy, [](ReturnT elem_operand) { return elem_operand; @@ -184,42 +202,19 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - template - std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { - DCHECK_EQ(src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< - typename primitive_util::PrimitiveTypeToNative::type, - typename 
primitive_util::PrimitiveTypeToNative::type>( - src_literal); - } - - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override { - auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); - - switch (operand->shape().element_type()) { -#define CONVERT_IF_TYPES_MATCH(src_type) \ - case (src_type): \ - parent_->evaluated_[convert] = LiteralUtil::Convert< \ - typename primitive_util::PrimitiveTypeToNative::type, \ - ReturnT>(operand_literal); \ - break; - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. - default: - LOG(FATAL) << "unimplemented operand type for HandleCovert: " - << PrimitiveType_Name(operand->shape().element_type()); + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).Convert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); } - return Status::OK(); } @@ -322,8 +317,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMaximum(HloInstruction* maximum) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { @@ -332,8 +326,7 @@ class HloEvaluator::TypedVisitor : public 
DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMinimum(HloInstruction* minimum) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { @@ -409,6 +402,258 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs, + HloInstruction* rhs, const Window& window) override { + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), conv->shape())); + TF_CHECK_OK(ShapeUtil::ValidateShape(lhs->shape())); + TF_CHECK_OK(ShapeUtil::ValidateShape(rhs->shape())); + + const auto& dnums = conv->convolution_dimension_numbers(); + const int64 num_spatial_dims = dnums.spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); + CHECK_GE(num_spatial_dims, 1); + CHECK_EQ(window.dimensions_size(), num_spatial_dims); + + CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(lhs->shape())); + CHECK_EQ(num_spatial_dims + 2, ShapeUtil::Rank(rhs->shape())); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), window, dnums)); + CHECK(ShapeUtil::Compatible(conv->shape(), inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(conv->shape()) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + + // Dimension number applicable for both input (lhs), and output. 
+ const int64 batch_dim = dnums.batch_dimension(); + const int64 z_dim = dnums.feature_dimension(); + // Dimension number applicable for kernel (rhs). + const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); + const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + + const int64 z_size = ShapeUtil::GetDimension(lhs->shape(), z_dim); + + std::vector window_dimension_sizes; + for (auto i : dnums.kernel_spatial_dimensions()) { + window_dimension_sizes.push_back( + ShapeUtil::GetDimension(rhs->shape(), i)); + } + + const Shape& window_shape = ShapeUtil::MakeShape( + rhs->shape().element_type(), window_dimension_sizes); + + auto result = Literal::CreateFromShape(conv->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice out_index) { + ReturnT result_val = static_cast(0); + + std::vector lhs_index(lhs_rank, 0); + std::vector rhs_index(rhs_rank, 0); + + lhs_index[batch_dim] = out_index[batch_dim]; + rhs_index[kernel_output_z_dim] = out_index[z_dim]; + + std::vector rhs_spatial_index( + dnums.kernel_spatial_dimensions_size(), 0); + + // Convolve input feature with kernel. + do { + for (int64 iz = 0; iz < z_size; ++iz) { + lhs_index[z_dim] = iz; + rhs_index[kernel_input_z_dim] = iz; + + // Find corresponding spatial dimension index for input (lhs). + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 spatial_dim = dnums.spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const int64 undilated_index = + out_index[spatial_dim] * window.dimensions(ki).stride() - + window.dimensions(ki).padding_low() + + rhs_spatial_index[ki] * + window.dimensions(ki).window_dilation(); + // Skip if the lhs (input) index is to be dilated. + if (undilated_index % window.dimensions(ki).base_dilation() != + 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. 
+ lhs_index[spatial_dim] = + undilated_index / window.dimensions(ki).base_dilation(); + + // Skip if input index is not in bound. + if (!(lhs_index[spatial_dim] >= 0 && + lhs_index[spatial_dim] < + lhs->shape().dimensions(spatial_dim))) { + goto cnt; + } + + rhs_index[dnums.kernel_spatial_dimensions(ki)] = + rhs_spatial_index[ki]; + } + + result_val += lhs_literal.Get(lhs_index) * + rhs_literal.Get(rhs_index); + } + cnt:; + } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + + return result_val; + })); + + parent_->evaluated_[conv] = std::move(result); + return Status::OK(); + }; + + Status HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) override { + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + // Dot only supports operands of rank 1 and 2. + const auto dot_rank = ShapeUtil::Rank(dot->shape()); + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + CHECK(lhs_rank > 0 && lhs_rank <= 2); + CHECK(rhs_rank > 0 && rhs_rank <= 2); + CHECK_EQ(dot_rank, lhs_rank + rhs_rank - 2); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // Check contracted dimensions are the same. + // + // Determine the index of the contracted dimensions for input tensors. + // dimensions -1 of lhs and dimension 0 of rhs are contracted. 
+ const int64 lhs_contracted_dimension = + ShapeUtil::GetDimensionNumber(lhs->shape(), -1); + const int64 rhs_contracted_dimension = 0; + CHECK_EQ(lhs->shape().dimensions(lhs_contracted_dimension), + rhs->shape().dimensions(rhs_contracted_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracted_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracted_dimension); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracted_dimension); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = Literal::CreateFromShape(dot->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + ReturnT result_val = static_cast(0); + + std::vector lhs_index(lhs_rank, 0); + std::vector rhs_index(rhs_rank, 0); + // Set index for non-contracted dimension for lhs and rhs. + if (lhs_rank > 1) { + lhs_index[0] = multi_index[0]; + } + if (rhs_rank > 1) { + rhs_index[1] = multi_index[multi_index.size() - 1]; + } + + // Accumulates resulting product along the contracted dimension. + for (int64 i = 0; i < contracted_dimension_size; ++i) { + lhs_index[lhs_contracted_dimension] = i; + rhs_index[rhs_contracted_dimension] = i; + + result_val += lhs_literal.Get(lhs_index) * + rhs_literal.Get(rhs_index); + } + + return result_val; + })); + + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + }; + + Status HandlePad(HloInstruction* pad) override { + CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + // Padding value must be scalar. 
+ CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); + CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + pad->padding_config().dimensions_size()); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferPadShape( + /*operand_shape=*/pad->operand(0)->shape(), + /*padding_value_shape=*/pad->operand(1)->shape(), + /*padding_config=*/pad->padding_config())); + CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + // Create a new literal of the padded shape, filled with the padding value. + ReturnT scalar = + parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); + auto result = Literal::CreateFromShape(pad->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&scalar](tensorflow::gtl::ArraySlice multi_index) { + return scalar; + })); + + auto evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); + + std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), + 0); + std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + + // Loop through each element of the operand and assign it to the + // corresponding index of the resulting padded literal. + const PaddingConfig& pad_config = pad->padding_config(); + + auto func = [&](const std::vector& input_index) { + for (auto i = 0; i < input_index.size(); ++i) { + // Interior padding occurs logically before edge padding, so in the case + // of negative edge padding elements are removed from the + // interior-padded operand. + target_index[i] = + pad_config.dimensions(i).edge_padding_low() + + input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); + + // Account for negative low and high padding: skip assignment if + // any target index is out of range. 
+ if (!(target_index[i] >= 0 && + target_index[i] < pad->shape().dimensions(i))) { + return true; + } + } + result->Set(target_index, + evaluated_operand.Get(input_index)); + return true; + }; + + std::vector zero_base(evaluated_operand.shape().dimensions_size(), + 0); + std::vector step(evaluated_operand.shape().dimensions_size(), 1); + + ShapeUtil::ForEachIndex( + evaluated_operand.shape(), zero_base, + AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); + + parent_->evaluated_[pad] = std::move(result); + return Status::OK(); + }; + Status Preprocess(HloInstruction* hlo) override { VLOG(2) << hlo->ToString(); return Status::OK(); @@ -446,12 +691,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return binary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index)); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return binary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); } @@ -483,14 +728,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return ternary_op( - LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index), - LiteralUtil::Get(ehs_literal, multi_index)); + 
TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index), + ehs_literal.Get(multi_index)); })); return std::move(result); @@ -552,7 +796,7 @@ StatusOr> HloEvaluator::Evaluate( if (operand->opcode() == HloOpcode::kParameter) { const Literal* input_literal = arg_literals_[operand->parameter_number()]; VLOG(2) << "Parameter operand evaluated to: " - << LiteralUtil::ToString(*input_literal); + << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); evaluated_[operand] = MakeUnique(*input_literal); @@ -589,8 +833,7 @@ std::unique_ptr HloEvaluator::TryEvaluate( Status HloEvaluator::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; - VLOG(2) << "Parameter evaluated to: " - << LiteralUtil::ToString(*input_literal); + VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); evaluated_[parameter] = MakeUnique(*input_literal); @@ -606,14 +849,14 @@ Status HloEvaluator::HandleConstant(HloInstruction* constant, Status HloEvaluator::HandleReshape(HloInstruction* reshape) { TF_ASSIGN_OR_RETURN( evaluated_[reshape], - LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)), - AsInt64Slice(reshape->shape().dimensions()))); + GetEvaluatedLiteralFor(reshape->operand(0)) + .Reshape(AsInt64Slice(reshape->shape().dimensions()))); return Status::OK(); } Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { - evaluated_[transpose] = LiteralUtil::Transpose( - GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions()); + evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0)) + .Transpose(transpose->dimensions()); return Status::OK(); } @@ -641,16 +884,16 @@ Status 
HloEvaluator::HandleConcatenate( ShapeUtil::GetDimension(operand_shape, concat_dim); } - auto result_literal = LiteralUtil::CreateFromDimensions( + auto result_literal = Literal::CreateFromDimensions( reference_shape.element_type(), concat_dimensions); DimensionVector source_indices(rank, 0); DimensionVector dest_indices(concat_dimensions.size(), 0); for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(), - dest_indices, AsInt64Slice(operand_shape.dimensions()))); + TF_RETURN_IF_ERROR(result_literal->Copy( + GetEvaluatedLiteralFor(operand), source_indices, dest_indices, + AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += ShapeUtil::GetDimension(operand_shape, concat_dim); } @@ -775,14 +1018,14 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, Status HloEvaluator::HandleSlice(HloInstruction* slice, HloInstruction* operand) { const Shape& shape = slice->shape(); - auto literal = LiteralUtil::CreateFromDimensions( + auto literal = Literal::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); DimensionVector dest_indices(slice->slice_starts().size(), 0); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(), - dest_indices, AsInt64Slice(shape.dimensions()))); + TF_RETURN_IF_ERROR(literal->Copy(GetEvaluatedLiteralFor(operand), + slice->slice_starts(), dest_indices, + AsInt64Slice(shape.dimensions()))); evaluated_[slice] = std::move(literal); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 91fd56f54c592b8bbe68f6b38e761e1f10a20c8b..976a2325ea970f570748a6872d7bf2459f8ffa4a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -92,7 +92,8 @@ class 
HloEvaluator : public DfsHloVisitorWithDefault { return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); } - // Operations that are type-agnostic. + // Operations that are type-agnostic or always return a specific type, such as + // HandleIsFinite where boolean is always returned. // Status HandleParameter(HloInstruction* parameter) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index b26ece28b756097b06b4a04d4873775e13760014..626bd3b02b1d2c8cafed196cdf82f05d24017516 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -14,27 +14,33 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include #include #include +#include #include #include +#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -class HloEvaluatorTest : public ::testing::Test { +class HloEvaluatorTest : public HloTestBase { protected: HloEvaluatorTest() { evaluator_ = MakeUnique(); } @@ -44,9 +50,9 @@ class HloEvaluatorTest : public ::testing::Test { // Verifies that 
HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. TEST_F(HloEvaluatorTest, DoesClamp) { - auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); + auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = low->shape(); auto c1 = HloInstruction::CreateConstant(std::move(low)); @@ -58,17 +64,17 @@ TEST_F(HloEvaluatorTest, DoesClamp) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); + auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. 
TEST_F(HloEvaluatorTest, DoesSelect) { - auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); - auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto pred = Literal::CreateR2({{true, false}, {false, true}}); + auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = on_true->shape(); auto c1 = HloInstruction::CreateConstant(std::move(pred)); @@ -80,16 +86,16 @@ TEST_F(HloEvaluatorTest, DoesSelect) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); + auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. TEST_F(HloEvaluatorTest, DoesAdd) { - auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(lhs)); @@ -100,16 +106,16 @@ TEST_F(HloEvaluatorTest, DoesAdd) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); + auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. 
TEST_F(HloEvaluatorTest, DoesDivide) { - auto lhs_s64 = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs_s64 = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); auto c1_s64 = HloInstruction::CreateConstant(std::move(lhs_s64)); @@ -120,12 +126,12 @@ TEST_F(HloEvaluatorTest, DoesDivide) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); + auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); - auto lhs_f64 = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs_f64 = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); auto c1_f64 = HloInstruction::CreateConstant(std::move(lhs_f64)); @@ -135,16 +141,15 @@ TEST_F(HloEvaluatorTest, DoesDivide) { result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - expected = - LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + expected = Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. 
TEST_F(HloEvaluatorTest, DoesAbs) { - auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(operand)); auto instruction = @@ -153,42 +158,40 @@ TEST_F(HloEvaluatorTest, DoesAbs) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); + auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); // For R0 literal. const Shape& r0 = ShapeUtil::MakeShape(F32, {}); - operand = LiteralUtil::CreateR0(-1.0f); + operand = Literal::CreateR0(-1.0f); c1 = HloInstruction::CreateConstant(std::move(operand)); instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get()); result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); - expected = LiteralUtil::CreateR0(1.0f); + expected = Literal::CreateR0(1.0f); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); // For R1 literal with dimension of size 0. Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); - operand = LiteralUtil::CreateR1({}); + operand = Literal::CreateR1({}); c1 = HloInstruction::CreateConstant(std::move(operand)); instruction = HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get()); result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); - expected = LiteralUtil::CreateR1({}); + expected = Literal::CreateR1({}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // namespace // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. 
-TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { - HloComputation::Builder builder( - ::testing::UnitTest::GetInstance()->current_test_info()->name()); - - auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); - auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); +TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { + HloComputation::Builder builder(TestName()); + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -206,21 +209,19 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { std::unique_ptr result = evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); + auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies Reshape operation is correctly evaluated. 
TEST_F(HloEvaluatorTest, DoesReshape) { - HloComputation::Builder builder( - ::testing::UnitTest::GetInstance()->current_test_info()->name()); - + HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSIGN_OR_ASSERT_OK(auto literal, LiteralTestUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = LiteralUtil::CloneToUnique(*literal); + auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -233,13 +234,717 @@ TEST_F(HloEvaluatorTest, DoesReshape) { evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - LiteralUtil::EachCell( - *result, [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + result->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_TRUE(value == - LiteralUtil::Get(*literal_clone, rindexes)); + EXPECT_TRUE(value == literal_clone->Get(rindexes)); }); } +// Verifies Broadcast operation is correctly evaluated. 
+TEST_F(HloEvaluatorTest, DoesBroadcast) { + HloComputation::Builder builder(TestName()); + auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto output_literal = Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}}); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + + builder.AddInstruction(HloInstruction::CreateBroadcast( + output_literal->shape(), literal_instruction, {1, 2})); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*result, *output_literal); +} + +TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { + HloComputation::Builder builder(TestName()); + + auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto expected = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), + expected->shape())); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + builder.AddInstruction( + HloInstruction::CreateConvert(expected->shape(), constant)); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*result, *expected); +} + +TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { + HloComputation::Builder builder(TestName()); + + auto input_literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); + auto expected = Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), + expected->shape())); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + builder.AddInstruction( + HloInstruction::CreateConvert(expected->shape(), 
constant)); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*result, *expected); +} + +PaddingConfig CreatePaddingConfig( + std::initializer_list> padding_dimensions) { + PaddingConfig padding_config; + + for (auto& paddings_per_dim : padding_dimensions) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(paddings_per_dim[0]); + dimension->set_edge_padding_high(paddings_per_dim[1]); + dimension->set_interior_padding(paddings_per_dim[2]); + } + return padding_config; +} + +TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { + auto operand = Literal::CreateR2({{}, {}}); + auto operand_instruction = HloInstruction::CreateConstant(std::move(operand)); + + constexpr int32 kPadValue = 10; + auto pad_value = Literal::CreateR0(kPadValue); + auto padding_value_instruction = + HloInstruction::CreateConstant(std::move(pad_value)); + + auto padding_config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}}); + Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); + auto pad_instruction = HloInstruction::CreatePad( + shape, operand_instruction.get(), padding_value_instruction.get(), + padding_config); + + auto result = evaluator_->Evaluate(pad_instruction.get()).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2( + {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { + HloComputation::Builder b(TestName()); + + Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); + auto input = Literal::CreateR4FromArray4D(input_array); + HloInstruction* input_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + constexpr float kPadValue = 1.5; + auto pad_value = Literal::CreateR0(kPadValue); + HloInstruction* pad_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value))); + + Shape 
shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1}); + auto r4_padding_on_dim0_dim1 = + CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); + b.AddInstruction(HloInstruction::CreatePad( + shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + auto expected_array = MakeUnique>(8, 5, 1, 1); + expected_array->Fill(kPadValue); + (*expected_array)(1, 0, 0, 0) = 1.0f; + (*expected_array)(1, 2, 0, 0) = 2.0f; + (*expected_array)(4, 0, 0, 0) = 3.0f; + (*expected_array)(4, 2, 0, 0) = 4.0f; + (*expected_array)(7, 0, 0, 0) = 5.0f; + (*expected_array)(7, 2, 0, 0) = 6.0f; + + auto expected = Literal::CreateR4FromArray4D(*expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, NegativePadding2D) { + HloComputation::Builder b(TestName()); + + // input_array: + // f32[4,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // { 13, 14, 15 }, + // } + auto input_array = MakeUnique>(4, 3); + input_array->FillUnique(1.0f); + auto input = Literal::CreateR2FromArray2D(*input_array); + HloInstruction* input_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + + auto pad_value_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + + auto r2_padding_on_dim0_dim1 = + CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 5}); + b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction, + pad_value_instruction, + r2_padding_on_dim0_dim1)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } + auto expected_array = MakeUnique>(1, 5); + (*expected_array)(0, 0) = 7.0f; + (*expected_array)(0, 1) = 2.718f; + (*expected_array)(0, 2) = 2.718f; + (*expected_array)(0, 3) = 2.718f; + 
(*expected_array)(0, 4) = 2.718f; + auto expected = Literal::CreateR2FromArray2D(*expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { + HloComputation::Builder b(TestName()); + + // f32[4,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // { 13, 14, 15 }, + // } + auto input_array = MakeUnique>(4, 3); + input_array->FillUnique(1.0f); + auto input = Literal::CreateR2FromArray2D(*input_array); + HloInstruction* input_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + + auto pad_value_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + + PaddingConfig padding_config = MakeNoPaddingConfig(2); + + // Negative padding that results in zero dimensions. + auto r2_padding_on_dim0_dim1 = + CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}}); + + Shape shape = ShapeUtil::MakeShape(F32, {0, 9}); + b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction, + pad_value_instruction, + r2_padding_on_dim0_dim1)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + auto expected_array = MakeUnique>(0, 9); + auto expected = Literal::CreateR2FromArray2D(*expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DotRank2AndRank1) { + HloComputation::Builder b(TestName()); + + // lhs: + // f32[4,1] { + // { 1 }, + // { 2 }, + // { 3 }, + // { 4 }, + // } + auto lhs_array = MakeUnique>(4, 1); + lhs_array->FillUnique(1.0f); + auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + // rhs: + // f32[2] { 1, 2 }, + auto rhs_literal = Literal::CreateR2({{1, 2}}); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = 
ShapeUtil::MakeShape(F32, {4, 2}); + b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + // clang-format off + auto expected_array = Array2D({ + {1.f, 2.f}, + {2.f, 4.f}, + {3.f, 6.f}, + {4.f, 8.f}, + }); + // clang-format on + auto expected = Literal::CreateR2FromArray2D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DotRank1AndRank2) { + HloComputation::Builder b(TestName()); + + // lhs: + // f32[3] + // { 1, 2, 3 }, + auto lhs_literal = Literal::CreateR1({1, 2, 3}); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + // rhs: + // f32[3,2] { + // { 1, 2 }, + // { 3, 4 }, + // { 5, 6 }, + // } + auto rhs_array = MakeUnique>(3, 2); + rhs_array->FillUnique(1.0f); + auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {2}); + b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR1({22.f, 28.f}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DotRank2AndRank2) { + HloComputation::Builder b(TestName()); + + // lhs: + // f32[4,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // { 13, 14, 15 }, + // } + auto lhs_array = MakeUnique>(4, 3); + lhs_array->FillUnique(1.0f); + auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + // rhs: + // f32[3,2] { + // { 1, 2 }, + // { 3, 4 }, + // { 5, 6 
}, + // } + auto rhs_array = MakeUnique>(3, 2); + rhs_array->FillUnique(1.0f); + auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); + b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + auto expected_array = Array2D({ + {22.f, 28.f}, {58.f, 76.f}, {94.f, 124.f}, {130.f, 172.f}, + }); + auto expected = Literal::CreateR2FromArray2D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, SimpleConv1D) { + HloComputation::Builder b(TestName()); + + Array3D lhs_array = {{{1, 2, 3}}}; + auto lhs_literal = Literal::CreateR3FromArray3D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array3D rhs_array = {{{3.f, 4.f}}}; + auto rhs_literal = Literal::CreateR3FromArray3D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(0); + dnums.set_feature_dimension(1); + dnums.add_spatial_dimensions(2); + + dnums.set_kernel_output_feature_dimension(0); + dnums.set_kernel_input_feature_dimension(1); + dnums.add_kernel_spatial_dimensions(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), 
{}).ConsumeValueOrDie(); + + Array3D expected_array = {{{11.f, 18.f, 9.f}}}; + auto expected = Literal::CreateR3FromArray3D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { + HloComputation::Builder b(TestName()); + + Array4D lhs_array(1, 1, 4, 4); + // clang-format off + lhs_array.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array4D rhs_array(1, 1, 2, 2); + // clang-format off + rhs_array.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + Array4D expected_array(1, 1, 4, 4); + // clang-format off + expected_array.FillWithYX(Array2D({ + {100, 126, 152, 76}, + {204, 230, 256, 124}, + {308, 334, 360, 172}, + {149, 160, 171, 80}, + })); + // clang-format on + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, 
Conv2DGeneralDimensions) { + HloComputation::Builder b(TestName()); + + // clang-format off + // Input dimensions: [feature=2, height=3, batch=1, width=4] + Array4D input({ + {{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}}, + {{{13, 14, 15, 16}}, + {{17, 18, 19, 20}}, + {{21, 22, 23, 24}}} + }); + // Weight dimensions: + // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3] + Array4D weight({{ + {{1, 7, 13}, + {4, 10, 16}}, + {{2, 8, 14}, + {5, 11, 17}}, + {{3, 9, 15}, + {6, 12, 18}} + }}); + // clang-format on + + auto lhs_literal = Literal::CreateR4FromArray4D(input); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + auto rhs_literal = Literal::CreateR4FromArray4D(weight); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(3); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums; + dnums.set_batch_dimension(2); + dnums.set_feature_dimension(0); + dnums.add_spatial_dimensions(1); + dnums.add_spatial_dimensions(3); + + dnums.set_kernel_output_feature_dimension(0); + dnums.set_kernel_input_feature_dimension(2); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(1); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + // clang-format off + // Result dimensions: [feature=1, height=1, batch=1, width=2] + Array4D expected_array({{{{2514, 2685}}}}); + // clang-format on + auto expected = Literal::CreateR4FromArray4D(expected_array); + + 
LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { + HloComputation::Builder b(TestName()); + + Array4D lhs_array(1, 1, 4, 4); + // clang-format off + lhs_array.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array4D rhs_array(1, 1, 2, 2); + // clang-format off + rhs_array.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(2); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + Array4D expected_array(1, 1, 7, 7); + expected_array.FillWithYX(Array2D({ + {5, 12, 10, 18, 15, 24, 20}, + {35, 48, 42, 56, 49, 64, 56}, + {25, 36, 30, 42, 35, 48, 40}, + {63, 80, 70, 88, 77, 96, 84}, + {45, 60, 50, 66, 55, 72, 60}, + {91, 112, 98, 120, 105, 128, 112}, + {65, 84, 70, 90, 75, 96, 80}, + })); + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { + HloComputation::Builder 
b(TestName()); + + Array4D lhs_array(1, 1, 4, 4); + // clang-format off + lhs_array.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array4D rhs_array(1, 1, 2, 2); + // clang-format off + rhs_array.FillWithYX(Array2D({ + {5, 6}, + {7, 8}, + })); + // clang-format on + auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(1); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(2); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + Array4D expected_array(1, 1, 8, 8); + expected_array.FillWithYX(Array2D({ + {8, 7, 16, 14, 24, 21, 32, 28}, + {6, 5, 12, 10, 18, 15, 24, 20}, + {40, 35, 48, 42, 56, 49, 64, 56}, + {30, 25, 36, 30, 42, 35, 48, 40}, + {72, 63, 80, 70, 88, 77, 96, 84}, + {54, 45, 60, 50, 66, 55, 72, 60}, + {104, 91, 112, 98, 120, 105, 128, 112}, + {78, 65, 84, 70, 90, 75, 96, 80}, + })); + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, + DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { + HloComputation::Builder b(TestName()); + + Array4D lhs_array(1, 1, 4, 4); + 
// clang-format off + lhs_array.FillWithYX(Array2D({ + {1, 2, 3, 4 }, + {5, 6, 7, 8 }, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + })); + // clang-format on + auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + Array4D rhs_array(1, 1, 2, 3); + // clang-format off + rhs_array.FillWithYX(Array2D({ + {5, 6, 7}, + {8, 9, 10}, + })); + // clang-format on + auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(2); + dim.set_padding_high(2); + dim.set_window_dilation(2); + dim.set_base_dilation(2); + *window.add_dimensions() = dim; + dim.set_size(3); + dim.set_stride(3); + dim.set_padding_low(2); + dim.set_padding_high(-1); + dim.set_window_dilation(1); + dim.set_base_dilation(3); + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums = + ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + + std::unique_ptr result = + evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie(); + + Array4D expected_array(1, 1, 9, 3); + expected_array.FillWithYX(Array2D({ + {10, 20, 30}, + {0, 0, 0}, + {57, 74, 91}, + {0, 0, 0}, + {125, 142, 159}, + {0, 0, 0}, + {193, 210, 227}, + {0, 0, 0}, + {91, 98, 105}, + })); + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 
eb2e5dfb37f33fd138e20ee930a2242cb1db89ea..7bb0ca2329d263e0d0a251e0526ee4cc41b7dfea 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" @@ -34,6 +33,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" using ::tensorflow::Env; using ::tensorflow::WriteStringToFile; @@ -214,6 +214,7 @@ string InstructionSequenceGraph( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kConvert: + case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kExp: @@ -282,6 +283,10 @@ string InstructionSequenceGraph( // port for each parameter instruction. No need to emit anything in this // case. continue; + case HloOpcode::kBatchNormTraining: + StrAppend(&name, " feature_index=", instruction->feature_index()); + color = kPurple; + break; case HloOpcode::kReduce: StrAppend(&name, " dims=", Join(instruction->dimensions(), ",")); color = kPurple; @@ -313,6 +318,11 @@ string InstructionSequenceGraph( StrAppend(&name, "
", "custom_call_target=", instruction->custom_call_target()); break; + case HloOpcode::kReducePrecision: + // Make ReducePrecision ops a bit more visible, since typically they + // will be inserted as modifications to an existing graph. + color = kDarkRed; + break; } // Create instruction node with appropriate label, shape, and color. @@ -325,8 +335,7 @@ string InstructionSequenceGraph( ShapeUtil::IsEffectiveScalar(instruction->shape())) { auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( instruction->shape(), /*linear_index=*/0); - StrAppend(&label, " = {", - LiteralUtil::GetAsString(instruction->literal(), elem_idx), + StrAppend(&label, " = {", instruction->literal().GetAsString(elem_idx), "}"); } @@ -508,10 +517,9 @@ namespace { class FileGraphRenderer : public GraphRendererInterface { public: - string RenderGraph(const string& graph, GraphKind graph_kind) override { + string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) override { static std::atomic output_num(0); - legacy_flags::HloGraphDumperFlags* flags = - legacy_flags::GetHloGraphDumperFlags(); string file_extension; switch (graph_kind) { case DOT_GRAPH: @@ -522,7 +530,7 @@ class FileGraphRenderer : public GraphRendererInterface { break; } string path = - JoinPath(flags->xla_hlo_dump_graph_path, + JoinPath(debug_options.xla_hlo_graph_path(), StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); auto status = Status::OK(); int fd = mkstemps(&path[0], file_extension.length()); @@ -548,13 +556,11 @@ XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); } // namespace string DumpGraph(const HloComputation& computation, const string& label, - bool show_addresses, bool show_layouts, + const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile) { string graph; string graph_url; - legacy_flags::HloGraphDumperFlags* flags = - legacy_flags::GetHloGraphDumperFlags(); - if (flags->xla_hlo_dump_as_graphdef) { + if 
(debug_options.xla_hlo_dump_as_graphdef()) { HloTfGraphBuilder builder; TF_CHECK_OK(builder.AddComputation(computation)); CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), @@ -563,12 +569,13 @@ string DumpGraph(const HloComputation& computation, const string& label, // renderers support rendering GraphDefs. Always dump GraphDefs to files // for now. graph_url = FileGraphRenderer().RenderGraph( - graph, GraphRendererInterface::TF_GRAPHDEF); + graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); } else { - graph = ComputationToDotGraph(computation, label, show_addresses, - show_layouts, hlo_execution_profile); + graph = ComputationToDotGraph( + computation, label, debug_options.xla_hlo_graph_addresses(), + debug_options.xla_hlo_graph_layout(), hlo_execution_profile); graph_url = GetGraphRenderer()->RenderGraph( - graph, GraphRendererInterface::DOT_GRAPH); + graph, GraphRendererInterface::DOT_GRAPH, debug_options); } LOG(INFO) << "computation " << computation.name() << " [" << label << "]: " << graph_url; @@ -584,6 +591,30 @@ void DumpText(const HloModule& module, const string& label, do_prefix ? 
StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); string path = JoinPath(directory_path, filename); TF_CHECK_OK(WriteStringToFile(env, path, module.ToString())); + LOG(INFO) << "dumping module '" << module.name() << "' to " << path; +} + +string MaybeDumpHloModule(const HloModule& module, const string& label, + const HloExecutionProfile* profile) { + VLOG(2) << "MaybeDumpHloModule called on module " << module.name(); + string graph_url; + const DebugOptions& debug_options = module.config().debug_options(); + if (!debug_options.xla_generate_hlo_graph().empty() && + RE2::PartialMatch(module.name(), + debug_options.xla_generate_hlo_graph())) { + graph_url = + DumpGraph(*module.entry_computation(), label, debug_options, profile); + } + if (!debug_options.xla_log_hlo_text().empty() && + RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) { + LOG(INFO) << "HLO for module " << module.name(); + LOG(INFO) << "Label: " << label; + XLA_LOG_LINES(2, module.ToString()); + } + if (!debug_options.xla_generate_hlo_text_to().empty()) { + DumpText(module, label, debug_options.xla_generate_hlo_text_to()); + } + return graph_url; } } // namespace hlo_graph_dumper diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 8ed50c38473a6f6dd36603e155285e855ff0c5be..bc404a7a37fe43705ae0ace38c714ad649218fdd 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" namespace xla { namespace hlo_graph_dumper { @@ -38,14 +39,22 @@ class GraphRendererInterface { // Renders a DOT graph, returning a description of the rendered output // (e.g., a URL) - virtual string RenderGraph(const string& graph, GraphKind graph_kind) = 0; + virtual string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) = 0; }; +// Dump the given HLO module if a dump is requested in its debug options. Based +// on the debug options, either a graph dump, a text dump or both may be +// generated. If a graph dump is generated, the description (e.g. an URL) is +// returned; otherwise an empty string is returned. +string MaybeDumpHloModule(const HloModule& module, const string& label, + const HloExecutionProfile* profile = nullptr); + // Dumps a graph of the computation and returns a description of the rendered // graph (e.g., a URL) based on the renderer. The "best" renderer in the // registry is used. 
string DumpGraph(const HloComputation& computation, const string& label, - bool show_addresses, bool show_layouts, + const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr); // Dumps the HloModule::ToString() as a file into the provided directory path diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index ea813c98743f7c34a891a3b648a2818f5dada8ec..f0f97c80f744915ab3145939a45fc32ede43db56 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -122,6 +122,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: @@ -226,6 +227,19 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateReducePrecision(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); + instruction->AppendOperand(operand); + instruction->exponent_bits_ = exponent_bits; + instruction->mantissa_bits_ = mantissa_bits; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum(const Shape& shape, HloInstruction* operand) { @@ -299,6 +313,12 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); instruction->slice_strides_.assign(strides.begin(), strides.end()); + // For backward compatibility with old serialized computations: if there are + // no strides, assume all strides are 1. + // TODO(b/63317920): remove this code. 
+ if (instruction->slice_strides_.empty()) { + instruction->slice_strides_ = std::vector(start_indices.size(), 1LL); + } return instruction; } @@ -371,6 +391,22 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBatchNormTraining(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, float epsilon, + int64 feature_index) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(scale); + instruction->AppendOperand(offset); + instruction->epsilon_ = epsilon; + instruction->feature_index_ = feature_index; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, @@ -730,6 +766,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kIsFinite: case HloOpcode::kFloor: @@ -780,6 +817,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); return CreateConvert(shape, new_operands[0]); + case HloOpcode::kReducePrecision: + CHECK_EQ(new_operands.size(), 1); + return CreateReducePrecision(shape, new_operands[0], exponent_bits_, + mantissa_bits_); case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); return CreateConvolve(shape, new_operands[0], new_operands[1], *window_, @@ -838,11 +879,16 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return CreateWhile(shape, while_condition(), while_body(), new_operands[0]); case HloOpcode::kConstant: - return CreateConstant(LiteralUtil::CloneToUnique(*literal_)); + return CreateConstant(literal_->CloneToUnique()); case HloOpcode::kFusion: return CloneFusionWithNewOperands(shape, 
new_operands); case HloOpcode::kParameter: return CreateParameter(parameter_number_, shape, parameter_name_); + case HloOpcode::kBatchNormTraining: + CHECK_EQ(new_operands.size(), 3); + return CreateBatchNormTraining(shape, new_operands[0], new_operands[1], + new_operands[2], epsilon(), + feature_index()); // Unsupported ops for cloning. case HloOpcode::kRecv: case HloOpcode::kSend: @@ -1041,7 +1087,7 @@ Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) { auto pred_it = std::find(instruction->control_predecessors_.begin(), instruction->control_predecessors_.end(), this); TF_RET_CHECK(pred_it != instruction->control_predecessors_.end()); - instruction->control_predecessors_.erase(succ_it); + instruction->control_predecessors_.erase(pred_it); return Status::OK(); } @@ -1099,6 +1145,7 @@ bool HloInstruction::Identical( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: case HloOpcode::kDot: @@ -1141,15 +1188,24 @@ bool HloInstruction::Identical( // different HloComputations. ShapeUtil::Compatible(shape(), other.shape()); + case HloOpcode::kBatchNormTraining: + return feature_index() == other.feature_index() && + epsilon() == other.epsilon(); + // A constant is defined by the value in the literal. case HloOpcode::kConstant: - return LiteralUtil::Equal(literal(), other.literal()); + return literal().Equal(other.literal()); // A convert result is determined by the primitive type that the operand is // converted into. case HloOpcode::kConvert: return shape().element_type() == other.shape().element_type(); + // A reduce-precision operation is determined by the bit sizes. + case HloOpcode::kReducePrecision: + return exponent_bits() == other.exponent_bits() && + mantissa_bits() == other.mantissa_bits(); + // Convolution has a window and dimensions. 
case HloOpcode::kConvolution: return protobuf_util::ProtobufEquals(window(), other.window()) && @@ -1438,10 +1494,10 @@ string HloInstruction::ToString(bool compact_operands, string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. - if (ShapeUtil::ElementsIn(shape()) <= 10) { - // LiteralUtil::ToString emits multidimensional arrays over multiple + if (!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) { + // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. - string tmp = LiteralUtil::ToString(literal()); + string tmp = literal().ToString(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); std::vector v = tensorflow::str_util::Split(tmp, ' '); bool first = true; @@ -1455,7 +1511,7 @@ string HloInstruction::ToString(bool compact_operands, first = false; } } else { - // Do not show large constants. + // Do not show large constants or tuples. operands = "{...}"; } } else if (opcode() == HloOpcode::kParameter) { @@ -1565,7 +1621,7 @@ HloInstructionProto HloInstruction::ToProto() const { case HloOpcode::kFusion: { HloComputationProto* proto_fused_computation = proto.mutable_fused_instructions_computation(); - proto_fused_computation->set_name(FullyQualifiedName()); + proto_fused_computation->set_name(name()); // Fill in fused instructions. Note that fused_instructions() returns in // reverse post-order (i.e. root first), so we reverse to get post-order. 
@@ -1629,6 +1685,8 @@ string HloInstruction::ToCategory() const { case FusionKind::kConvBackwardFilter: case FusionKind::kConvBackwardInput: return "convolution fusion"; + case FusionKind::kCustom: + return "custom fusion"; } } @@ -1639,14 +1697,6 @@ string HloInstruction::ToCategory() const { return HloOpcodeString(opcode()); } -string HloInstruction::FullyQualifiedName() const { - if (IsFused()) { - return StrCat(fusion_instruction()->parent()->name(), - "::", fusion_instruction()->name(), "::", name_); - } - return StrCat(parent_->name(), "::", name_); -} - HloInstruction* HloInstruction::tracing() const { return trace_instruction_; } void HloInstruction::set_tracing(HloInstruction* trace_instruction) { @@ -1736,6 +1786,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { switch (opcode_) { case HloOpcode::kAbs: return visitor->HandleAbs(this, operands_[0]); + case HloOpcode::kBatchNormTraining: + return visitor->HandleBatchNormTraining(this); case HloOpcode::kSign: return visitor->HandleSign(this, operands_[0]); case HloOpcode::kConstant: @@ -1758,9 +1810,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kSubtract: return visitor->HandleSubtract(this, operands_[0], operands_[1]); case HloOpcode::kMaximum: - return visitor->HandleMaximum(this, operands_[0], operands_[1]); + return visitor->HandleMaximum(this); case HloOpcode::kMinimum: - return visitor->HandleMinimum(this, operands_[0], operands_[1]); + return visitor->HandleMinimum(this); case HloOpcode::kLogicalAnd: return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]); case HloOpcode::kLogicalOr: @@ -1768,9 +1820,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kConcatenate: return visitor->HandleConcatenate(this, operands_); case HloOpcode::kConvert: - return visitor->HandleConvert(this, operands_[0]); + return visitor->HandleConvert(this); case HloOpcode::kCopy: - return visitor->HandleCopy(this, operands_[0]); + return 
visitor->HandleCopy(this); case HloOpcode::kMultiply: return visitor->HandleMultiply(this, operands_[0], operands_[1]); case HloOpcode::kDot: @@ -1814,6 +1866,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleLog(this, operands_[0]); case HloOpcode::kTanh: return visitor->HandleTanh(this, operands_[0]); + case HloOpcode::kCos: + return visitor->HandleCos(this, operands_[0]); case HloOpcode::kIsFinite: return visitor->HandleIsFinite(this, operands_[0]); case HloOpcode::kLogicalNot: @@ -1830,6 +1884,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleTranspose(this); case HloOpcode::kReverse: return visitor->HandleReverse(this, operands_[0]); + case HloOpcode::kReducePrecision: + return visitor->HandleReducePrecision(this); case HloOpcode::kSlice: return visitor->HandleSlice(this, operands_[0]); case HloOpcode::kDynamicSlice: @@ -1868,72 +1924,90 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { HloOpcodeString(opcode_).c_str()); } -Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order, - bool ignore_control_predecessors) { - // Do not visit this HLO node again if it is already visited. - if (visitor->DidVisit(*this)) { - VLOG(3) << "Not visiting HLO " << name() << " as it was already visited."; - return Status::OK(); +static Status PushDFSChild(DfsHloVisitor* visitor, + std::vector* dfs_stack, + HloInstruction* parent, HloInstruction* child) { + switch (visitor->GetVisitState(*child)) { + case DfsHloVisitor::kVisiting: + return FailedPrecondition( + "A cycle is detected while visiting instruction %s", + parent->ToString().c_str()); + + case DfsHloVisitor::kVisited: + VLOG(3) << "Not visiting HLO " << child->name() + << " as it was already visited."; + return Status::OK(); + + case DfsHloVisitor::kNotVisited: + dfs_stack->push_back(child); + return Status::OK(); } +} - // If the instruction is in the visiting state, it means a cycle. 
- if (visitor->IsVisiting(*this)) { - return FailedPrecondition( - "A cycle is detected while visiting instruction %s", - ToString().c_str()); - } - visitor->SetVisiting(*this); - - // Sort operands, if an ordering was provided. 'temp_sorted_operands' must - // live at this scope, since 'operands' will point to it if the operands are - // sorted. The purpose of the 'operands' pointer is to avoid copying the - // operands in the common case where the operands are not sorted. - std::vector* operands = &operands_; - std::vector temp_sorted_operands; - if (operand_order != nullptr) { - temp_sorted_operands = operands_; - std::sort(temp_sorted_operands.begin(), temp_sorted_operands.end(), - *operand_order); - operands = &temp_sorted_operands; - } - for (HloInstruction* operand : *operands) { - VLOG(3) << "Going to visit HLO " << operand->name() << " as operand of HLO " - << name(); - TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor, operand_order, - ignore_control_predecessors)); - } - - if (!ignore_control_predecessors) { - // This uses the same pointer/vector sorting to avoid extra copies as above. 
- std::vector* predecessors = &control_predecessors_; - std::vector temp_sorted_predecessors; - if (operand_order != nullptr) { - temp_sorted_predecessors = control_predecessors_; - std::sort(temp_sorted_predecessors.begin(), - temp_sorted_predecessors.end(), *operand_order); - predecessors = &temp_sorted_predecessors; +static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, + const HloInstruction::CompareFunction* operand_order, + bool ignore_control_predecessors) { + std::vector dfs_stack; + dfs_stack.push_back(root); + + do { + DCHECK(!dfs_stack.empty()); + + HloInstruction* current_node = dfs_stack.back(); + DfsHloVisitor::VisitState visit_state = + visitor->GetVisitState(*current_node); + if (visit_state == DfsHloVisitor::kVisited) { + dfs_stack.pop_back(); + VLOG(3) << "Not visiting HLO " << current_node->name() + << " as it was already visited."; + continue; } - for (HloInstruction* control_predecessor : *predecessors) { - VLOG(3) << "Going to visit HLO " << control_predecessor->name() - << " as a control predecessor of HLO " << name(); - TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal( - visitor, operand_order, ignore_control_predecessors)); + + if (visit_state == DfsHloVisitor::kVisiting) { + dfs_stack.pop_back(); + + TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); + VLOG(2) << "Visiting HLO " << current_node->name(); + TF_RETURN_IF_ERROR(current_node->Visit(visitor)); + visitor->SetVisited(*current_node); + TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); + continue; } - } - TF_RETURN_IF_ERROR(visitor->Preprocess(this)); - VLOG(2) << "Visiting HLO " << name(); - TF_RETURN_IF_ERROR(Visit(visitor)); - visitor->SetVisited(*this); - return visitor->Postprocess(this); + visitor->SetVisiting(*current_node); + + const size_t old_dfs_stack_size = dfs_stack.size(); + + for (HloInstruction* child : current_node->operands()) { + TF_RETURN_IF_ERROR( + PushDFSChild(visitor, &dfs_stack, current_node, child)); + } + + if 
(!ignore_control_predecessors) { + for (HloInstruction* child : current_node->control_predecessors()) { + TF_RETURN_IF_ERROR( + PushDFSChild(visitor, &dfs_stack, current_node, child)); + } + } + + if (operand_order != nullptr) { + std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(), + *operand_order); + } + + // This makes the traversal order the same as what you'd expect + // out of a recursive algorithm. + std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end()); + } while (!dfs_stack.empty()); + + return Status::OK(); } Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, bool ignore_control_predecessors) { VLOG(2) << "HloInstruction::Accept(" << name() << ")"; TF_RETURN_IF_ERROR( - AcceptInternal(visitor, nullptr, ignore_control_predecessors)); + PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } @@ -1944,8 +2018,8 @@ Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; - TF_RETURN_IF_ERROR(AcceptInternal(visitor, &operand_order, - /*ignore_control_predecessors=*/false)); + TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order, + /*ignore_control_predecessors=*/false)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } @@ -2060,12 +2134,14 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCeil: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLogicalNot: case HloOpcode::kNegate: + case HloOpcode::kReducePrecision: case HloOpcode::kSign: case HloOpcode::kTanh: return true; @@ -2274,6 +2350,8 @@ string ToString(HloInstruction::FusionKind kind) { return "kConvBackwardFilter"; case 
HloInstruction::FusionKind::kConvBackwardInput: return "kConvBackwardInput"; + case HloInstruction::FusionKind::kCustom: + return "kCustom"; } } @@ -2348,4 +2426,9 @@ void HloInstruction::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } +void HloInstruction::set_outer_dimension_partitions( + const std::vector& outer_dimension_partitions) { + outer_dimension_partitions_ = outer_dimension_partitions; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index c7cd729934b2a52d95b32b4ba5f5c84dc087cfd4..44eacc450c054b8facb28412735f106a2dfb9a1a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -63,6 +63,9 @@ class HloInstruction { kTransposeDot, // Fused into a dot with transposed operands. kConvBackwardFilter, // Fused into a backward filter convolution. kConvBackwardInput, // Fused into a backward input convolution. + + kCustom, // Custom category for backend-specific fusions that + // do not match any of the more specific ones. }; ~HloInstruction(); @@ -131,6 +134,13 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates a reduce-precision op, where operand is the data to reduce in + // precision, and exponent_bits and mantissa_bits describe the precision to + // reduce it to. + static std::unique_ptr CreateReducePrecision( + const Shape& shape, HloInstruction* operand, const int exponent_bits, + const int mantissa_bits); + // Creates a cross replica sum op. static std::unique_ptr CreateCrossReplicaSum( const Shape& shape, HloInstruction* operand); @@ -209,6 +219,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation); + // Creates a batch-norm-training instruction. 
+ static std::unique_ptr CreateBatchNormTraining( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, float epsilon, int64 feature_index); + // Creates a scatter computation that scatters the `source` array to the // selected indices of each window. static std::unique_ptr CreateSelectAndScatter( @@ -510,11 +525,6 @@ class HloInstruction { // or "elementwise". string ToCategory() const; - // Returns the string concatenation of parent name and this instructions - // name. This name is guaranteed to be unique among all instructions in the - // HloModule. - string FullyQualifiedName() const; - // Returns a logging instruction, if the output of this instruction is logged. // // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace @@ -528,6 +538,18 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv int64 channel_id() const { return channel_id_; } + // Returns feature_index field associated with the instruction. The index + // represents the index of the feature dimension. + // + // Precondition: opcode() == HloOpcode::kBatchNormTraining + int64 feature_index() const { return feature_index_; } + + // Returns a epsilon value associated with the instruction. The is a small + // number added to the variance to avoid divide-by-zero error. + // + // Precondition: opcode() == HloOpcode::kBatchNormTraining + float epsilon() const { return epsilon_; } + // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. @@ -661,6 +683,22 @@ class HloInstruction { return dynamic_slice_sizes_; } + // Returns the number of exponent bits for a reduce-precision node. 
+ // + // Precondition: opcode() == HloOpcode::kReducePrecision + int32 exponent_bits() const { + CHECK_EQ(HloOpcode::kReducePrecision, opcode_); + return exponent_bits_; + } + + // Returns the number of mantissa bits for a reduce-precision node. + // + // Precondition: opcode() == HloOpcode::kReducePrecision + int32 mantissa_bits() const { + CHECK_EQ(HloOpcode::kReducePrecision, opcode_); + return mantissa_bits_; + } + // Returns data on the window in a windowed operation such as // convolution. const Window& window() const { @@ -708,6 +746,16 @@ class HloInstruction { return called_computations_; } + // Replaces all called computations based on a map function. This is needed + // when we clone hlo_computations and want to let the instructions to point + // to the newly cloned nodes. + void ReplaceCalledComputations( + std::function map_function) { + for (int64 i = 0; i < called_computations_.size(); ++i) { + called_computations_[i] = map_function(called_computations_[i]); + } + } + // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, // after performing necessary implicit broadcast @@ -742,9 +790,9 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns the opcode string for this instruction. Compared with - // HloOpcodeString method, this wrapper dumps additional information - // such as fusion kind. + // Returns the opcode string for this instruction. This is the result from + // HloOpcodeString plus, for fusion nodes, the fusion kind, separated by a + // ':'. string ExtendedOpcodeStr() const; // Returns a string identifier for this instruction. If no string identifier @@ -782,6 +830,17 @@ class HloInstruction { parent_fusion_instruction_ = fusion_instruction; } + // Get/Set the number of partitions per outer dimension (in order, starting + // with outer-most dimension first). 
Currently used by the parallel cpu + // backend to partition HLOs into parallel tasks. + // TODO(b/62783254) Replace these methods with a more general way to + // annotate HLOs with backend-specific information. + const std::vector& outer_dimension_partitions() const { + return outer_dimension_partitions_; + } + void set_outer_dimension_partitions( + const std::vector& outer_dimension_partitions); + private: enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; @@ -818,12 +877,6 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands); - // Inner DFS traversal function -- this function being called (rather than - // Accept above) allows us to distinguish the root of the traversal. - Status AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order, - bool ignore_control_predecessors); - // CHECKs various invariants of a fusion instruction. void CheckFusionInstruction() const; @@ -864,6 +917,10 @@ class HloInstruction { std::vector slice_limits_; std::vector slice_strides_; + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_; + int32 mantissa_bits_; + // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). std::vector dynamic_slice_sizes_; @@ -934,6 +991,14 @@ class HloInstruction { // Only present for kRng. RandomDistribution distribution_; + // A small float number added to the variance to avoid divide-by-zero error. + // Only present for kBatchNormTraining. + float epsilon_; + + // An integer value representing the index of the feature dimension. + // Only present for kBatchNormTraining. + int64 feature_index_; + // Represents a unique identifier for each Send/Recv instruction pair. // Only present for kSend or kRecv. int64 channel_id_ = -1; @@ -950,6 +1015,10 @@ class HloInstruction { // Metadata for debugging. 
OpMetadata metadata_; + // The number of partitions per outer dimension (listed in order from + // outer-most dimension first). + std::vector outer_dimension_partitions_; + TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index bcf81cd8ddf63eff2f1df9c6c797588eee42f6b5..bb1b477e1397f9094ef2f8f4cd69f9fee074f65e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -232,7 +232,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { // ------- auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0"); auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1"); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0.get(), c0.get()); auto addright = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, @@ -271,7 +271,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { // ------- auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0"); auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1"); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto neg1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0.get()); auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0.get(), neg1.get()); @@ -307,7 +307,7 @@ TEST_F(HloInstructionTest, TrivialMap) { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "x")); auto value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); builder.AddInstruction( 
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); auto add_f32 = builder.Build(); @@ -349,9 +349,8 @@ TEST_F(HloInstructionTest, TrivialReduce) { // Builds a parameter and an initial value and feeds them to the reduce. auto param0 = HloInstruction::CreateParameter(0, f32a100x10, ""); - auto const0 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto const0 = HloInstruction::CreateConstant(Literal::CreateR0(0.0f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto reduce = HloInstruction::CreateReduce(f32v100, param0.get(), const0.get(), /*dimensions_to_reduce=*/{1}, add_f32.get()); @@ -560,7 +559,7 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { TEST_F(HloInstructionTest, SingletonFusionOp) { // Create a fusion instruction containing a single unary operation. auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); @@ -574,9 +573,9 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { TEST_F(HloInstructionTest, BinaryFusionOp) { // Create a fusion instruction containing a single binary operation. auto constant1 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto constant2 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(42.1f)); auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1.get(), constant2.get()); @@ -594,7 +593,7 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { TEST_F(HloInstructionTest, ChainFusionOp) { // Create a chain of fused unary ops. 
auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto exp1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); @@ -613,7 +612,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { // Create a chain of fused unary ops. auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto exp1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); @@ -644,7 +643,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { std::unique_ptr computation_y = make_map_computation(); auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto map_1_x = HloInstruction::CreateMap(scalar_shape, {constant.get()}, computation_x.get(), /*static_operands=*/{}); @@ -681,9 +680,9 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // // Notable complexities are repeated operands in a same instruction, different // shapes, use of value in different expressions. - auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); - auto c2 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.1f)); - auto c3 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(9.0f)); + auto c1 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); + auto c2 = HloInstruction::CreateConstant(Literal::CreateR0(2.1f)); + auto c3 = HloInstruction::CreateConstant(Literal::CreateR0(9.0f)); auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get()); @@ -732,11 +731,11 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { // Create a set of random constant operands to use below. 
Make them matrices // so dimensions are interesting. auto operand1 = HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); auto operand2 = HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); - auto vector_operand = HloInstruction::CreateConstant( - LiteralUtil::CreateR1({42.0, 123.0})); + Literal::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); + auto vector_operand = + HloInstruction::CreateConstant(Literal::CreateR1({42.0, 123.0})); Shape shape = operand1->shape(); // Convenient short names for the operands. diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 141251011cc0b4205b6069ff90415492ead9f7a9..79f17bbb6bd9bfc0c6ed48c68599ef51fbd27af8 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -95,6 +95,7 @@ HLO_MATCHER(Parameter); HLO_MATCHER(Power); HLO_MATCHER(Recv); HLO_MATCHER(Reduce); +HLO_MATCHER(ReducePrecision); HLO_MATCHER(ReduceWindow); HLO_MATCHER(Remainder); HLO_MATCHER(Reshape); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 22ef9c590bcf63a4e0c60931f771455601b0c019..da6f1d77ecb82ddbce11ca43c184ce0552b757fa 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -37,19 +37,17 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(name), config_(config), - entry_computation_(nullptr), has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle), - computation_name_uniquer_(/*separator=*/".") {} + entry_computation_handle_(entry_computation_handle) {} -HloModule::HloModule(const string& name) - : name_(name), - entry_computation_(nullptr), - computation_name_uniquer_(/*separator=*/".") {} +HloModule::HloModule(const string& name) : name_(name) {} 
HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation) { computation->UniquifyName(&computation_name_uniquer_); + for (auto& instruction : computation->instructions()) { + instruction->UniquifyName(&instruction_name_uniquer_); + } computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -301,6 +299,36 @@ std::list HloModule::MakeComputationPostOrder() const { return post_order; } +std::unique_ptr HloModule::Clone(const string& suffix) { + VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; + auto module = MakeUnique(name_ + "-" + suffix); + module->config_ = config_; + module->entry_computation_handle_ = entry_computation_handle_; + module->has_entry_computation_handle_ = has_entry_computation_handle_; + + std::unordered_map clone_map; + for (auto& computation : computations_) { + auto cloned_computation = computation->Clone(suffix); + InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); + + if (entry_computation_ == computation.get()) { + module->AddEntryComputation(std::move(cloned_computation)); + } else { + module->AddEmbeddedComputation(std::move(cloned_computation)); + } + } + + for (auto& cloned_computation : module->computations_) { + for (auto& instruction : cloned_computation->instructions()) { + // Rewrite instruction's called_computation to point to the cloned + // computations. 
+ instruction->ReplaceCalledComputations( + [&](HloComputation* hlo) { return FindOrDie(clone_map, hlo); }); + } + } + return module; +} + uint64 HloModule::RandomNew64() const { tensorflow::mutex_lock l(rng_mutex_); return rng_(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 4b14b4fd62a460ede0639e4417507ff2af02abd6..ae8ec02fbd1a59fa1f4a4a6160de6db0c033c4b1 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -75,6 +75,9 @@ class HloModule { const string& name() const { return name_; } + // Returns a deep copy of this module including all computations. + std::unique_ptr Clone(const string& suffix = "clone"); + // Return a pointer to the entry computation of the module.. HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); @@ -121,13 +124,16 @@ class HloModule { return computation_name_uniquer_.GetUniqueName(prefix); } + // Returns the NameUniquer for uniquing instruction names in this module. + NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation); const string name_; HloModuleConfig config_; - HloComputation* entry_computation_; + HloComputation* entry_computation_ = nullptr; std::vector> computations_; // Random number generator engine to use when generating random numbers per @@ -141,8 +147,10 @@ class HloModule { bool has_entry_computation_handle_ = false; VersionedComputationHandle entry_computation_handle_; - // Unique name generator for computation names, which are unique per module. - NameUniquer computation_name_uniquer_; + // Unique name generator for computation and instruction names, which are + // unique per module. 
+ NameUniquer computation_name_uniquer_{/*separator=*/"."}; + NameUniquer instruction_name_uniquer_{/*separator=*/"."}; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index a2235a268235860a633fdc5f26c5127574a9487c..8974deb530c2e4561b5ab57f43c65fd525db3617 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -58,6 +58,10 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::replica_count=", replica_count()); } StrAppend(&key, debug_options_.DebugString()); + if (intra_op_parallelism_threads() > 0) { + StrAppend(&key, "::intra_op_parallelism_threads=", + intra_op_parallelism_threads()); + } return key; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index ee32ab9bc4b5dd406d0dd9b6dfff52f852883dd9..2299200b5be969c065fded840709a3d6034efe47 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -92,6 +92,15 @@ class HloModuleConfig { debug_options_ = debug_options; } + // Sets/returns the number of intra op threads for this module. + void set_intra_op_parallelism_threads( + const int intra_op_parallelism_threads) { + intra_op_parallelism_threads_ = intra_op_parallelism_threads; + } + int64 intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; + } + private: // If you add new members, be sure to update compilation_cache_key. @@ -116,6 +125,10 @@ class HloModuleConfig { // The number of replicas to compile this binary for. int64 replica_count_ = 1; + // The target maximum parallelism at which to partition HLOs for parallel + // execution on the CPU backend. 
+ int64 intra_op_parallelism_threads_ = -1; + DebugOptions debug_options_; }; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 870bc729aec98a2959de5aa322850898502394ad..56dc5632035c625445018becfd25d69557e6232a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -38,7 +38,7 @@ class HloModuleTest : public HloTestBase { std::unique_ptr CreateConstantComputation() { auto builder = HloComputation::Builder("Constant"); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); return builder.Build(); } @@ -81,6 +81,30 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { EXPECT_EQ(computation2->name(), "Constant.1"); } +TEST_F(HloModuleTest, CloneTest) { + // Create and copy a module with a diamond call graph of computations. + auto module = CreateNewModule(); + auto computation1 = + module->AddEmbeddedComputation(CreateConstantComputation()); + auto computation2 = + module->AddEmbeddedComputation(CreateCallComputation({computation1})); + auto computation3 = + module->AddEmbeddedComputation(CreateCallComputation({computation1})); + module->AddEntryComputation( + CreateCallComputation({computation2, computation3})); + + auto post_order = module->MakeComputationPostOrder(); + auto cloned_module = module->Clone("copy"); + auto post_order_copied = cloned_module->MakeComputationPostOrder(); + + EXPECT_EQ(post_order.size(), post_order_copied.size()); + for (auto origin = post_order.begin(), copied = post_order_copied.begin(); + origin != post_order.end() && copied != post_order_copied.end(); + ++origin, ++copied) { + EXPECT_EQ((*origin)->name() + "copy", (*copied)->name()); + } +} + TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. 
auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index ceb0cdaa3169bb57e4ebb61ac1b2ea41f1ef7995..53a93f3dac2737491012f466a97750a8b3961ed3 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -19,11 +19,20 @@ limitations under the License. namespace xla { string HloOpcodeString(HloOpcode opcode) { + // Note: Do not use ':' in opcode strings. It is used as a special character + // in these places: + // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to + // separate the opcode from the fusion kind + // - In fully qualified names (HloInstruction::FullyQualifiedName()), to + // separate the qualifiers (name of the computation and potentially the + // fusion instruction) from the name switch (opcode) { case HloOpcode::kAbs: return "abs"; case HloOpcode::kAdd: return "add"; + case HloOpcode::kBatchNormTraining: + return "batch-norm-training"; case HloOpcode::kBitcast: return "bitcast"; case HloOpcode::kBroadcast: @@ -40,6 +49,8 @@ string HloOpcodeString(HloOpcode opcode) { return "convert"; case HloOpcode::kConvolution: return "convolution"; + case HloOpcode::kCos: + return "cosine"; case HloOpcode::kCrossReplicaSum: return "cross-replica-sum"; case HloOpcode::kCustomCall: @@ -112,6 +123,8 @@ string HloOpcodeString(HloOpcode opcode) { return "recv"; case HloOpcode::kReduce: return "reduce"; + case HloOpcode::kReducePrecision: + return "reduce-precision"; case HloOpcode::kReduceWindow: return "reduce-window"; case HloOpcode::kRemainder: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index e2cdbfdfa7a4b5509dccf9a83ffbd799f9ab1374..d1263219c01fe1025bda50c35e94de65c5f86d37 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -30,6 +30,7 @@ namespace xla { enum class HloOpcode { kAbs, 
kAdd, + kBatchNormTraining, kBitcast, kBroadcast, kCall, @@ -40,6 +41,7 @@ enum class HloOpcode { kConvert, kConvolution, kCopy, + kCos, kCrossReplicaSum, kCustomCall, kDivide, @@ -74,6 +76,7 @@ enum class HloOpcode { kPower, kRecv, kReduce, + kReducePrecision, kReduceWindow, kRemainder, kReshape, diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 72911ae9f91c175d729c3136959cf47029e8a695..7230682d0b1d05ddde8e6d3b7a65319c271669c5 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -15,13 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include #include #include -#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -113,6 +110,20 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, // a_ancestor and b_ancestor must be either both null or both non-null. CHECK_NE(b_ancestor, nullptr); CHECK_EQ(a_ancestor->parent(), b_ancestor->parent()); + + // If the common ancestor is a while instruction there is an additional + // ordering criteria which may apply. The condition computation is considered + // to execute before the body computation so if 'a' is in the condition and + // 'b' is in the body, then 'a' executes before 'b'. 
+ if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) { + const HloComputation* body = a_ancestor->while_body(); + const HloComputation* condition = a_ancestor->while_condition(); + if (call_graph_->InstructionIsNestedIn(a, condition) && + call_graph_->InstructionIsNestedIn(b, body)) { + return true; + } + } + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); } @@ -141,7 +152,7 @@ bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( CHECK_EQ(a->parent(), b->parent()); // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. - return strict_predecessors_.at(b->parent())->IsReachable(b, a); + return a != b && predecessors_.at(a->parent())->IsReachable(a, b); } string PredecessorHloOrdering::ToStringHelper(const string& name) const { @@ -153,10 +164,10 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { const auto all = computation->MakeInstructionPostOrder(); for (auto instruction : all) { pieces.push_back(tensorflow::strings::Printf( - " %s strict predecessors:", instruction->name().c_str())); + " %s predecessors:", instruction->name().c_str())); for (auto predecessor : all) { - if (strict_predecessors_.at(computation.get()) - ->IsReachable(instruction, predecessor)) { + if (predecessors_.at(computation.get()) + ->IsReachable(predecessor, instruction)) { pieces.push_back( tensorflow::strings::Printf(" %s", predecessor->name().c_str())); } @@ -172,8 +183,8 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // ordering based on dependencies. ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. 
for (auto& computation : module->computations()) { - strict_predecessors_.emplace(computation.get(), - computation->ComputeTransitiveOperands()); + predecessors_.emplace(computation.get(), + computation->ComputeReachability()); } } @@ -238,358 +249,6 @@ string SequentialHloOrdering::ToString() const { return tensorflow::str_util::Join(pieces, "\n"); } -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { - return 0; - } - - const HloModule* module = module_sequence.begin()->first->parent(); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. We run the heap simulation on the whole module, - // rather than summing each computation, since it gives us a better lower - // bound, by minimizing the liveness of sub-computations. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, - module_sequence, *points_to_analysis, size_function)); - return result.heap_size; -} - -namespace { - -// Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage. -class ListScheduler { - public: - // Construct and return a memory-minimizing sequence of HLO instructions - // containing the given HLO computation. 
- static StatusOr> Run( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - ListScheduler scheduler(computation, points_to_analysis, size_function); - return scheduler.CreateSchedule(); - } - - private: - // The scheduling priority of an instruction is first the number of bytes - // freed by scheduling the instruction, and second (tie-breaker) by the number - // of users. This is represented as a std::pair containing these two values - // (first element is the bytes freed). std::pair provides the necessary - // comparison operators. - using Priority = std::pair; - - ListScheduler(const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) - : computation_(computation), - points_to_analysis_(points_to_analysis), - size_function_(size_function) { - // Create a map containing the LogicalBuffer uses for each HLO - // instruction. An HLO instruction "uses" a LogicalBuffer if the - // LogicalBuffer is in an operand of the instruction as indicated by - // points-to analysis. - for (auto& instruction : computation.instructions()) { - buffer_uses_.insert( - {instruction.get(), std::unordered_set()}); - for (auto* operand : instruction->operands()) { - for (const LogicalBuffer* buffer : - points_to_analysis.GetBuffersDefinedByInstruction(operand)) { - buffer_uses_[instruction.get()].insert(buffer); - } - } - } - - // Create map containing the number of unscheduled uses (hlo instructions) - // of each logical buffer. 
- for (auto& instruction : computation.instructions()) { - for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( - instruction.get())) { - unscheduled_use_count_[buffer] = 0; - } - } - for (auto& instruction : computation.instructions()) { - for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { - ++unscheduled_use_count_[buffer]; - } - } - - // Buffers live out of the computation have an implicit use at the end of - // the computation. - for (const LogicalBuffer* live_out_buffer : - points_to_analysis.GetPointsToSet(computation.root_instruction()) - .CreateFlattenedSet()) { - ++unscheduled_use_count_[live_out_buffer]; - } - } - - // Returns whether the memory used by the given buffer should be ignored by - // the scheduling heuristic. - bool IgnoreBuffer(const LogicalBuffer& buffer) { - return buffer.instruction()->opcode() == HloOpcode::kParameter || - buffer.instruction()->opcode() == HloOpcode::kConstant; - } - - // Return the number of bytes freed if the HLO instruction is scheduled. - int64 BytesFreedIfScheduled(const HloInstruction* instruction) { - int64 freed_bytes = 0; - // Sum the total size of the values last used by this instruction. - for (auto* buffer : buffer_uses_.at(instruction)) { - if (IgnoreBuffer(*buffer)) { - continue; - } - CHECK_GE(unscheduled_use_count_.at(buffer), 1); - if (unscheduled_use_count_.at(buffer) == 1) { - // This is the last use of the logical buffer. - freed_bytes += size_function_(*buffer); - } - } - // Then subtract the size of the value(s) defined by this instruction. - for (auto* buffer : - points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { - if (!IgnoreBuffer(*buffer)) { - freed_bytes -= size_function_(*buffer); - } - } - return freed_bytes; - } - - // Construct the scheduling priority of the given instruction. 
- Priority GetPriority(const HloInstruction* instruction) { - return {BytesFreedIfScheduled(instruction), instruction->user_count()}; - } - - std::vector CreateSchedule() { - std::vector schedule; - - // Populate the ready list with instructions which have no operands or - // control predecessors. - std::unordered_map unscheduled_pred_count; - std::list ready_list; - for (auto& instruction : computation_.instructions()) { - // TODO(b/34466113): Replace this and above with successors() or - // predecessors() when these methods are added to HloInstruction. - for (const HloInstruction* user : instruction->users()) { - unscheduled_pred_count[user]++; - } - for (const HloInstruction* succ : instruction->control_successors()) { - unscheduled_pred_count[succ]++; - } - } - for (auto& instruction : computation_.instructions()) { - // Instruction with no operands or control predecessors will - // not be in the map. - if (unscheduled_pred_count.count(instruction.get()) == 0) { - ready_list.push_back(instruction.get()); - } - } - - while (!ready_list.empty()) { - // Select the highest priority HLO instruction from the ready list. - auto best_it = ready_list.begin(); - Priority best_priority = GetPriority(*best_it); - for (auto ready_it = std::next(ready_list.begin()); - ready_it != ready_list.end(); ++ready_it) { - Priority priority = GetPriority(*ready_it); - if (priority > best_priority) { - best_it = ready_it; - best_priority = priority; - } - } - - // Remove the selected instruction from the ready list and add it to the - // schedule. - const HloInstruction* best = *best_it; - ready_list.erase(best_it); - schedule.push_back(best); - scheduled_instructions_.insert(best); - - // Update the unscheduled uses of the logical buffers. - for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { - CHECK_GT(unscheduled_use_count_.at(buffer), 0); - --unscheduled_use_count_[buffer]; - } - - // Add new instructions to ready list. 
- auto update_pred_count = [&unscheduled_pred_count, - &ready_list](HloInstruction* inst) { - int64 pred_count = --unscheduled_pred_count.at(inst); - CHECK_GE(pred_count, 0); - if (pred_count == 0) { - ready_list.push_back(inst); - } - }; - // TODO(b/34466113): Replace this and above with successors() or - // predecessors() when these methods are added to HloInstruction. - for (HloInstruction* user : best->users()) { - update_pred_count(user); - } - for (HloInstruction* succ : best->control_successors()) { - update_pred_count(succ); - } - } - CHECK_EQ(schedule.size(), computation_.instructions().size()); - CHECK_EQ(scheduled_instructions_.size(), - computation_.instructions().size()); - - return schedule; - } - - const HloComputation& computation_; - const TuplePointsToAnalysis& points_to_analysis_; - const LogicalBuffer::SizeFunction& size_function_; - - // A map containing the LogicalBuffers that each instruction uses. - std::unordered_map> - buffer_uses_; - - // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. - std::unordered_map unscheduled_use_count_; - - // Set of instructions which have been scheduled. - std::unordered_set scheduled_instructions_; -}; - -int64 SumLogicalBufferSizes(const std::vector& buffers, - const LogicalBuffer::SizeFunction& size_function) { - int64 size = 0; - for (const LogicalBuffer* buffer : buffers) { - size += size_function(*buffer); - } - return size; -} - -StatusOr> RunDFSMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // This ordering is based on DFS post-order, with a heuristic to decide which - // operand to visit first. The heuristic is based on 'extra_users', which is - // simply users-1 for each instruction. By subtracting 1, we're saying that - // instructions with no users or a single user don't count; instructions with - // lots of fan-out will be visited earlier. 
- tensorflow::gtl::FlatMap extra_users; - tensorflow::gtl::FlatMap total_sizes; - for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { - extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; - total_sizes[hlo] = SumLogicalBufferSizes( - points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); - tensorflow::gtl::FlatSet unique_operands( - hlo->operands().begin(), hlo->operands().end()); - for (const HloInstruction* operand : unique_operands) { - extra_users[hlo] += extra_users[operand]; - total_sizes[hlo] += total_sizes[operand]; - } - } - CHECK_EQ(extra_users.size(), computation.instructions().size()); - CHECK_EQ(total_sizes.size(), computation.instructions().size()); - - // Construct a total order based on DFS post-order, visiting operands in - // decreasing cumulative extra user order, and next by cumulative size, with a - // tiebreaker by name for determinism. - std::vector sequence; - FunctionVisitor visitor([&sequence](HloInstruction* hlo) { - sequence.push_back(hlo); - return Status::OK(); - }); - TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( - &visitor, [&extra_users, &total_sizes](const HloInstruction* a, - const HloInstruction* b) { - if (extra_users[a] != extra_users[b]) { - return extra_users[a] > extra_users[b]; - } - if (total_sizes[a] != total_sizes[b]) { - return total_sizes[a] > total_sizes[b]; - } - return a->name() < b->name(); - })); - CHECK_EQ(sequence.size(), computation.instructions().size()); - return sequence; -} - -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; -} - -StatusOr> CreateMemoryMinimizingSequence( - const HloComputation& 
computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // We try both a list-scheduler based ordering and a DFS based ordering, and - // choose whichever returns a lower min-memory, not accounting for - // fragmentation. - // - // Note that this is just a heuristic. One obvious inaccuracy is that the - // memory required for sub-computations might be different when considered - // within the caller's context. But it's good enough for now. - TF_ASSIGN_OR_RETURN( - std::vector list_sequence, - ListScheduler::Run(computation, points_to_analysis, size_function)); - TF_ASSIGN_OR_RETURN( - const int64 list_memory, - MinimumMemoryForComputation(computation, list_sequence, - points_to_analysis, size_function)); - VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; - - TF_ASSIGN_OR_RETURN( - std::vector dfs_sequence, - RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); - TF_ASSIGN_OR_RETURN( - const int64 dfs_memory, - MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, - size_function)); - VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; - - if (list_memory <= dfs_memory) { - VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; - return list_sequence; - } else { - VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; - return dfs_sequence; - } -} - -} // namespace - -StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(&module)); - for (const auto& computation : module.computations()) { - TF_ASSIGN_OR_RETURN(sequence[computation.get()], - CreateMemoryMinimizingSequence( - *computation, *points_to_analysis, size_function)); - } - return sequence; -} - -StatusOr> 
CreateMemoryMinimizingSequence( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation.parent())); - return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function); -} - std::ostream& operator<<( std::ostream& out, const SequentialHloOrdering::HloModuleSequence& module_sequence) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b59e1ea5eb0ad4882d4c2b96ee6ab6d1bc973993..130431f28070d52c3a76befa0d5272a3cc295711 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -24,12 +24,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -72,8 +68,8 @@ class HloOrdering { std::unique_ptr call_graph_; }; -// Base class for partial orderings implemented by a map of strict predecessors -// for each instruction. Subclasses should fill in strict_predecessors_. +// Base class for partial orderings implemented by a map of predecessors for +// each instruction. Subclasses should fill in predecessors_. class PredecessorHloOrdering : public HloOrdering { public: ~PredecessorHloOrdering() override = default; @@ -93,13 +89,12 @@ class PredecessorHloOrdering : public HloOrdering { const HloInstruction* b) const override; // For each computation in the module, this is the set of the instruction's - // strict predecessors. 
An instruction is not an element of its own strict - // predecessor set. + // predecessors. An instruction is an element of its own predecessor set. // // Subclasses should fill this in to define the desired ordering. tensorflow::gtl::FlatMap> - strict_predecessors_; + std::unique_ptr> + predecessors_; }; // An HLO ordering based on data dependencies in the HLO graph. In this partial @@ -191,24 +186,6 @@ std::ostream& operator<<( std::ostream& out, const SequentialHloOrdering::HloModuleSequence& module_sequence); -// Returns the minimum memory required to compute the given module sequence, -// assuming no fragmentation. -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); - -// Returns an HloModuleSequence which seeks to minimize the memory required for -// the computation. size_function is the function returning the number of bytes -// required for a LogicalBuffer. -StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function); - -// Overload of above that computes the sequence for a single computation. -StatusOr> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 21d852a51d67b2aadc0edea144f60a037a004614..a1e38803c43933a1afba190ac01eb712bbe6fce1 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -101,7 +102,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { auto builder_c = HloComputation::Builder("C"); HloInstruction* c = builder_c.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); HloComputation* computation_c = module->AddEmbeddedComputation(builder_c.Build()); @@ -155,67 +156,69 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { EXPECT_FALSE(ordering.ExecutesBefore(y, c)); } -class MinimumMemoryForSequenceTest : public HloTestBase {}; - -TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { +TEST_F(HloOrderingTest, InstructionsInWhileComputations) { + // Tests the ordering of instructions in the body and condition of a while + // instruction. 
HLO code: + // + // body(F32[]) %param): + // %negate = Negate(%param) + // + // condition(F32[] %param): + // %convert = Convert(%param) + // + // entry: + // %constant = Constant(1.0) + // return While(%constant, body, condition) + // auto module = CreateNewModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); - - auto cond_builder = HloComputation::Builder("WhileCond"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); - HloInstruction* cond_iter = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); - HloInstruction* cond_data = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); - // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) - HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); - HloComputation* cond_computation = - module->AddEmbeddedComputation(cond_builder.Build()); - auto body_builder = HloComputation::Builder("WhileBody"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "body_param")); - HloComputation* body_computation = - module->AddEmbeddedComputation(body_builder.Build()); + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "body_param")); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape, HloOpcode::kNegate, body_param)); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto 
cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "cond_param")); + auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(xla::PRED, {}), cond_param)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); - // Entry params: 8 bytes (4 bytes per param), TOTAL=8 - HloInstruction* iter = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); - HloInstruction* data = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param_data")); - // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 - HloInstruction* tuple = - builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); - // While: 8 bytes (4 bytes per element), TOTAL=32 - // Both cond and body use a max of 24 bytes, TOTAL=56 - HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, cond_computation, body_computation, tuple)); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const LogicalBuffer& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, - MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); + module->AddEntryComputation(builder.Build()); + + DependencyHloOrdering 
ordering(module.get()); + EXPECT_TRUE(ordering.ExecutesBefore(constant, xla_while)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, cond_param)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, convert)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, negate)); + + // The while should be unordered relative to the body and condition + // instructions. + EXPECT_FALSE(ordering.ExecutesBefore(xla_while, body_param)); + EXPECT_FALSE(ordering.ExecutesBefore(xla_while, cond_param)); + EXPECT_FALSE(ordering.ExecutesBefore(body_param, xla_while)); + EXPECT_FALSE(ordering.ExecutesBefore(cond_param, xla_while)); + + // Condition instructions should be ordered before body instructions. + EXPECT_TRUE(ordering.ExecutesBefore(cond_param, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(convert, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(cond_param, negate)); + EXPECT_TRUE(ordering.ExecutesBefore(convert, negate)); + + EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); } } // namespace - } // namespace xla int main(int argc, char** argv) { diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 119e2d79022dca094147348d83c59b9a04cb339f..4b824f8240074e7ae70b9d9fa82dfa0706d5b355 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -30,9 +31,10 @@ using ::tensorflow::strings::StrAppend; namespace xla { namespace { -void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module, +void DumpModule(const HloModule& module, + const string& message) { - dumper_(module, message); + hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(2) << "HLO " << message << ":"; XLA_VLOG_LINES(2, module.ToString()); } @@ -75,7 +77,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { // Emit label containing: "after foo-pass, before bar-pass". message.clear(); StrAppend(&message, prefix, ", before ", pass->name()); - DumpModule(dumper_, *module, message); + DumpModule(*module, message); TF_RETURN_IF_ERROR(run_invariant_checkers()); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); @@ -85,7 +87,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { StrAppend(&prefix, name(), ": after ", pass->name()); } TF_RETURN_IF_ERROR(run_invariant_checkers()); - DumpModule(dumper_, *module, prefix + ", pipeline end"); + DumpModule(*module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 682c4b952df6aae8cb933c222772dbd823070ecc..a42d7e59fed2d838dfe3cb7f99e6b946edfdb0b4 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -22,7 +22,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,9 +33,7 @@ namespace xla { // Pipeline of HLO passes. class HloPassPipeline : public HloPassInterface { public: - explicit HloPassPipeline(const string& name, - const Compiler::HloDumper& dumper) - : name_(name), dumper_(dumper) {} + explicit HloPassPipeline(const string& name) : name_(name) {} tensorflow::StringPiece name() const override { return name_; } // Add a pass to the pipeline. It should be called with the arguments for the @@ -69,7 +66,6 @@ class HloPassPipeline : public HloPassInterface { private: const string name_; - Compiler::HloDumper dumper_; std::vector> passes_; std::vector> invariant_checkers_; bool run_called_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index a153d73dbd838663c0d7e0d72ad54668f243f2c2..d45038f1f4a2e4aa19234eec93fdc9a068a902e1 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -25,7 +25,7 @@ namespace hlo_query { bool IsConstantR0F32(HloInstruction* instruction, float* out) { if (instruction->opcode() == HloOpcode::kConstant && ShapeUtil::IsScalarF32(instruction->shape())) { - *out = LiteralUtil::Get(instruction->literal(), {}); + *out = instruction->literal().Get({}); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb7ecbdc2a09e6e797d283675ccf2c26f9c1a34c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_reachability.h" + +namespace xla { + +HloReachabilityMap::HloReachabilityMap( + const std::list& instructions) + : size_(instructions.size()) { + bit_vectors_.reserve(size_); + for (const HloInstruction* hlo : instructions) { + indices_[hlo] = bit_vectors_.size(); + bit_vectors_.emplace_back(size_); + } + CHECK_EQ(size_, indices_.size()); // instructions should be unique +} + +bool HloReachabilityMap::SetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction) { + BitVector& bit_vector = GetBitVector(instruction); + tmp_bit_vector_ = bit_vector; + + bit_vector.SetToZero(); + bit_vector.Set(GetIndex(instruction)); + for (const HloInstruction* input : inputs) { + bit_vector.OrWith(GetBitVector(input)); + } + + return bit_vector != tmp_bit_vector_; +} + +void HloReachabilityMap::SetReachable(const HloInstruction* a, + const HloInstruction* b) { + GetBitVector(b).Set(GetIndex(a)); +} + +bool HloReachabilityMap::IsReachable(const HloInstruction* a, + const HloInstruction* b) const { + return GetBitVector(b).Get(GetIndex(a)); +} + +bool HloReachabilityMap::IsConnected(const HloInstruction* a, + const HloInstruction* b) const { + return IsReachable(a, b) || IsReachable(b, a); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h 
b/tensorflow/compiler/xla/service/hlo_reachability.h new file mode 100644 index 0000000000000000000000000000000000000000..d7bdac9c86579f19afbba133772c2c50894853d1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ + +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class HloInstruction; + +// A class for computing and representing reachability between HloInstructions. +class HloReachabilityMap { + public: + // Sets up an empty reachable matrix for the full set of instructions + // specified in 'instructions'. + explicit HloReachabilityMap(const std::list& instructions); + + // Set the reachability set of 'instruction' to the union of the reachability + // sets of 'inputs'. Upon return, IsReachable(x, instruction) where + // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true + // for some 'input' in 'inputs'. 
Also sets 'instruction' to be reachable from + // itself. Returns whether the reachability set of 'instruction' changed. + bool SetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction); + + // Sets entry so that IsReachable(a, b) will return true + void SetReachable(const HloInstruction* a, const HloInstruction* b); + + // Returns true if "b" is reachable from "a" + bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; + + // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + + private: + // A bit-vector implementation specialized for this use case which provides a + // fast bitwise OR operation not available in tensorflow::gtl::BitMap. + class BitVector { + public: + BitVector() = default; + BitVector(size_t size) + : size_(size), vector_((size + kBits - 1) / kBits, 0) {} + + // Return the bit at the given index. + bool Get(size_t index) const { + DCHECK(index >= 0 && index < size_); + return vector_[index / kBits] & (1ull << (index % kBits)); + } + + // Set the bit at the given index. + void Set(size_t index) { + DCHECK(index >= 0 && index < size_); + vector_[index / kBits] |= 1ull << (index % kBits); + } + + // Set this bitvector to the Logical OR of this bitvector and 'other'. + void OrWith(const BitVector& other) { + for (size_t i = 0; i < vector_.size(); ++i) { + vector_[i] |= other.vector_[i]; + } + } + + // Set the bitvector to all zeros. + void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); } + + bool operator==(const BitVector& other) const { + return vector_ == other.vector_; + } + bool operator!=(const BitVector& other) const { + return vector_ != other.vector_; + } + + private: + using Word = uint64; + static const size_t kBits = 64; + + // Number of bits in the bitvector. 
+ size_t size_; + + std::vector vector_; + }; + + // Return the bitvector storing the reachability-to of the given instruction. + const BitVector& GetBitVector(const HloInstruction* instruction) const { + return bit_vectors_[GetIndex(instruction)]; + } + BitVector& GetBitVector(const HloInstruction* instruction) { + return bit_vectors_[GetIndex(instruction)]; + } + + // Return the index of the given instruction. The value is used to index into + // the vector of BitVectors and the BitVectors themselves. + int GetIndex(const HloInstruction* instruction) const { + return FindOrDie(indices_, instruction); + } + + // The number of instructions in the reachability map. + const size_t size_; + + // Dense assignment from HloInstruction* to number. These numbers index + // into the bit_vectors_ vector and into the bits within a BitVector. + tensorflow::gtl::FlatMap indices_; + + // Bitvectors holding the reachability to each instruction. The bit vector for + // instruction X includes ones for each instruction which X is reachable from. + std::vector bit_vectors_; + + // A temporary used by SetReachabilityToUnion to avoid an allocation with each + // call to the method. + BitVector tmp_bit_vector_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..657a9ee83d29e72b95660325f9139f44159d6508 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_reachability.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { + +namespace { + +class HloReachabilityTest : public HloTestBase {}; + +TEST_F(HloReachabilityTest, Reachability) { + // Construct and test a reachability graph of the following form: + /* + a + / \ + b c + \ / \ + d e + */ + auto builder = HloComputation::Builder(TestName()); + auto a = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + auto b = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + auto c = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + auto d = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + auto e = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + builder.Build(); + + HloReachabilityMap reachability({a, b, c, d, e}); + reachability.SetReachable(a, a); + EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, b)); + EXPECT_TRUE(reachability.SetReachabilityToUnion({a}, c)); + EXPECT_TRUE(reachability.SetReachabilityToUnion({b, c}, d)); + EXPECT_TRUE(reachability.SetReachabilityToUnion({c}, e)); + + EXPECT_TRUE(reachability.IsReachable(a, a)); + EXPECT_TRUE(reachability.IsReachable(a, b)); + 
EXPECT_TRUE(reachability.IsReachable(a, c)); + EXPECT_TRUE(reachability.IsReachable(a, d)); + EXPECT_TRUE(reachability.IsReachable(a, e)); + + EXPECT_FALSE(reachability.IsReachable(b, a)); + EXPECT_TRUE(reachability.IsReachable(b, b)); + EXPECT_FALSE(reachability.IsReachable(b, c)); + EXPECT_TRUE(reachability.IsReachable(b, d)); + EXPECT_FALSE(reachability.IsReachable(b, e)); + + EXPECT_FALSE(reachability.IsReachable(e, a)); + EXPECT_FALSE(reachability.IsReachable(e, b)); + EXPECT_FALSE(reachability.IsReachable(e, c)); + EXPECT_FALSE(reachability.IsReachable(e, d)); + EXPECT_TRUE(reachability.IsReachable(e, e)); + + // Recomputing the same reachability for a previously computed instruction + // should return false (no change). + EXPECT_FALSE(reachability.SetReachabilityToUnion({a}, b)); + EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d)); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 2c1b0fff4e602a172cfa54d4eaa626198a426873..d19e8034acd664e1dc57fe0ce1df9b4b6ce4d9db 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -58,9 +59,8 @@ bool IsRematerializable(const HloInstruction* instruction) { return false; } - // Don't rematerialize instructions with side effects, those with a cost that - // might not be captured by HloCostAnalysis, or instructions which cannot be - // cloned safely. 
+ // Don't rematerialize instructions with side effects or instructions which + // cannot be cloned safely. switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kConstant: @@ -802,23 +802,14 @@ bool MemoryUsageTracker::Check() const { // Computes and returns the cost of rematerializing the given instruction. // Cost per rematerialized instruction is defined as: // -// (flop_count + transcendental_count + element_count) / memory_reduced +// memory_limit_bytes / memory_reduced // -// flop_count: from HloCostAnalysis -// transcendental_count: from HloCostAnalysis -// element_count: number of elements accessed in operands and output of -// instruction -// memory_reduced: The memory usage reduced by rematerializing the -// instruction. -// -// This is a rough estimate of the extra execution time per byte saved by -// rematerializing this instruction for its remaining uses. In general, we -// want the most memory saving for the least latency penalty which is captured -// by this heuristic. +// The idea is to choose the operation that will save the most memory for +// rematerialization and do not worry about how much the compute costs since +// running out of memory is more harmful than taking longer to get the answer. int64 RematerializationCost(const HloInstruction* instruction, const MemoryUsageTracker& memory_tracker, - const HloCostAnalysis& cost_analysis, - int64 memory_reduced) { + int64 memory_reduced, int64 memory_limit_bytes) { // If none of the users of 'instruction' have been placed in the sequence (as // tracked by memory_tracker), then rematerialization of 'instruction' is a // zero-cost move of 'instruction' in the sequence. @@ -830,22 +821,8 @@ int64 RematerializationCost(const HloInstruction* instruction, } CHECK_GT(memory_reduced, 0); - const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); - const int64 elements_accessed = - ShapeUtil::IsTuple(instruction->shape()) - ? 
bytes_accessed - : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType( - instruction->shape().element_type()); - - // Multiply by 256 to improve precision of cost. Without this factor, - // many instructions such as many elementwise instructions would have - // zero cost because the bytes reduced can be several times greater than - // the element count. - return 256 * - (cost_analysis.flop_count(*instruction) + - cost_analysis.transcendental_count(*instruction) + - elements_accessed) / - memory_reduced; + // Return the inverse of the benefit of rematerialization. + return memory_limit_bytes / memory_reduced; } // Selects and returns the best candidate instruction for rematerialization. @@ -856,8 +833,8 @@ int64 RematerializationCost(const HloInstruction* instruction, HloInstruction* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, - const HloCostAnalysis& cost_analysis, - const tensorflow::gtl::FlatSet& blacklist) { + const tensorflow::gtl::FlatSet& blacklist, + int64 memory_limit_bytes) { HloInstruction* best = nullptr; int64 best_cost = 0; @@ -891,12 +868,12 @@ HloInstruction* PickRematerializationCandidate( if (memory_reduced <= 0) { VLOG(5) << "candidate " << candidate->name() - << " memory reduced = " << memory_reduced << " <= 0"; + << " memory reduced = " << memory_reduced << " <= 0"; continue; } const int cost = RematerializationCost(candidate, memory_tracker, - cost_analysis, memory_reduced); + memory_reduced, memory_limit_bytes); VLOG(5) << "candidate " << candidate->name() << ", memory reduced " << memory_reduced << ", cost per byte " << cost; @@ -1011,7 +988,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); HloInstruction* best = PickRematerializationCandidate( - memory_tracker, instruction_list, cost_analysis_, blacklist); + memory_tracker, instruction_list, blacklist, memory_limit_bytes); if (best == nullptr) { 
VLOG(3) << "Unable to find rematerialization candidate at program " @@ -1211,11 +1188,6 @@ StatusOr HloRematerialization::Run( VLOG(1) << "Peak memory usage of module (before): " << HumanReadableNumBytes(before_peak_memory); - // Run cost analysis. Operation cost is used in the heuristic for selecting - // instructions for rematerialization. - TF_RETURN_IF_ERROR( - module->entry_computation()->root_instruction()->Accept(&cost_analysis_)); - // Subcomputations called by the entry computation will also be // rematerialized. TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 1693f93183bc59c343e3c765cb4051566d4377ef..42c279d440b78d90b9f19b92155c52787156e4b7 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -18,7 +18,6 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -61,7 +60,7 @@ class HloRematerialization { protected: HloRematerialization(const ShapeSizeFunction& size_function) - : size_function_(size_function), cost_analysis_(size_function_) {} + : size_function_(size_function) {} ~HloRematerialization() {} // Runs rematerialization on the given module. Returns whether the module was @@ -100,9 +99,6 @@ class HloRematerialization { // Call graph of the hlo_module. std::unique_ptr call_graph_; - // Analysis used for computing the rematerialization cost of instructions. - HloCostAnalysis cost_analysis_; - // The peak memory usage of each computation. 
The map contains only those // computations called from sequential context // (CallContext::kSequential). These values are updated as rematerialization diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index f306bcc309c6c5e57a311496ee0370741de8a6ab..3a935dcf968abbd3580d7e8df8c147afe7cdf8f8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -126,7 +126,7 @@ class HloRematerializationTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); return builder.Build(); } @@ -215,7 +215,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -254,7 +254,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -289,7 +289,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); 
HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -357,7 +357,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { /*dimension=*/0)); builder.AddInstruction(HloInstruction::CreateSlice( vec1024_shape_, concat, /*start_indices=*/{0}, - /*limit_indices=*/{1024}, /*slices=*/{1})); + /*limit_indices=*/{1024}, /*strides=*/{1})); subcomputation = module->AddEmbeddedComputation(builder.Build()); } @@ -473,7 +473,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { /*dimension=*/0)); builder.AddInstruction(HloInstruction::CreateSlice( vec1024_shape_, concat, /*start_indices=*/{0}, - /*limit_indices=*/{1024}, /*slices=*/{1})); + /*limit_indices=*/{1024}, /*strides=*/{1})); subcomputation = module->AddEmbeddedComputation(builder.Build()); } diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8e05448da988026da09ed19c2d8be6f262ea55c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -0,0 +1,388 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr MinimumMemoryForSequence( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; + } + + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; +} + +namespace { + +// Class implementing a list scheduler of HLO instructions which produces a +// sequence which minimizes memory usage. 
+class ListScheduler { + public: + // Construct and return a memory-minimizing sequence of HLO instructions + // containing the given HLO computation. + static StatusOr> Run( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + ListScheduler scheduler(computation, points_to_analysis, size_function); + return scheduler.CreateSchedule(); + } + + private: + // The scheduling priority of an instruction is first the number of bytes + // freed by scheduling the instruction, and second (tie-breaker) by the number + // of users. This is represented as a std::pair containing these two values + // (first element is the bytes freed). std::pair provides the necessary + // comparison operators. + using Priority = std::pair; + + ListScheduler(const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) + : computation_(computation), + points_to_analysis_(points_to_analysis), + size_function_(size_function) { + // Create a map containing the LogicalBuffer uses for each HLO + // instruction. An HLO instruction "uses" a LogicalBuffer if the + // LogicalBuffer is in an operand of the instruction as indicated by + // points-to analysis. + for (auto& instruction : computation.instructions()) { + buffer_uses_.insert( + {instruction.get(), std::unordered_set()}); + for (auto* operand : instruction->operands()) { + for (const LogicalBuffer* buffer : + points_to_analysis.GetBuffersDefinedByInstruction(operand)) { + buffer_uses_[instruction.get()].insert(buffer); + } + } + } + + // Create map containing the number of unscheduled uses (hlo instructions) + // of each logical buffer. 
+ for (auto& instruction : computation.instructions()) { + for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( + instruction.get())) { + unscheduled_use_count_[buffer] = 0; + } + } + for (auto& instruction : computation.instructions()) { + for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { + ++unscheduled_use_count_[buffer]; + } + } + + // Buffers live out of the computation have an implicit use at the end of + // the computation. + for (const LogicalBuffer* live_out_buffer : + points_to_analysis.GetPointsToSet(computation.root_instruction()) + .CreateFlattenedSet()) { + ++unscheduled_use_count_[live_out_buffer]; + } + } + + // Returns whether the memory used by the given buffer should be ignored by + // the scheduling heuristic. + bool IgnoreBuffer(const LogicalBuffer& buffer) { + return buffer.instruction()->opcode() == HloOpcode::kParameter || + buffer.instruction()->opcode() == HloOpcode::kConstant; + } + + // Return the number of bytes freed if the HLO instruction is scheduled. + int64 BytesFreedIfScheduled(const HloInstruction* instruction) { + int64 freed_bytes = 0; + // Sum the total size of the values last used by this instruction. + for (auto* buffer : buffer_uses_.at(instruction)) { + if (IgnoreBuffer(*buffer)) { + continue; + } + CHECK_GE(unscheduled_use_count_.at(buffer), 1); + if (unscheduled_use_count_.at(buffer) == 1) { + // This is the last use of the logical buffer. + freed_bytes += size_function_(*buffer); + } + } + // Then subtract the size of the value(s) defined by this instruction. + for (auto* buffer : + points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { + if (!IgnoreBuffer(*buffer)) { + freed_bytes -= size_function_(*buffer); + } + } + return freed_bytes; + } + + // Construct the scheduling priority of the given instruction. 
+ Priority GetPriority(const HloInstruction* instruction) { + return {BytesFreedIfScheduled(instruction), instruction->user_count()}; + } + + std::vector CreateSchedule() { + std::vector schedule; + + // Populate the ready list with instructions which have no operands or + // control predecessors. + std::unordered_map unscheduled_pred_count; + std::list ready_list; + for (auto& instruction : computation_.instructions()) { + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (const HloInstruction* user : instruction->users()) { + unscheduled_pred_count[user]++; + } + for (const HloInstruction* succ : instruction->control_successors()) { + unscheduled_pred_count[succ]++; + } + } + for (auto& instruction : computation_.instructions()) { + // Instruction with no operands or control predecessors will + // not be in the map. + if (unscheduled_pred_count.count(instruction.get()) == 0) { + ready_list.push_back(instruction.get()); + } + } + + while (!ready_list.empty()) { + // Select the highest priority HLO instruction from the ready list. + auto best_it = ready_list.begin(); + Priority best_priority = GetPriority(*best_it); + for (auto ready_it = std::next(ready_list.begin()); + ready_it != ready_list.end(); ++ready_it) { + Priority priority = GetPriority(*ready_it); + if (priority > best_priority) { + best_it = ready_it; + best_priority = priority; + } + } + + // Remove the selected instruction from the ready list and add it to the + // schedule. + const HloInstruction* best = *best_it; + ready_list.erase(best_it); + schedule.push_back(best); + scheduled_instructions_.insert(best); + + // Update the unscheduled uses of the logical buffers. + for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { + CHECK_GT(unscheduled_use_count_.at(buffer), 0); + --unscheduled_use_count_[buffer]; + } + + // Add new instructions to ready list. 
+ auto update_pred_count = [&unscheduled_pred_count, + &ready_list](HloInstruction* inst) { + int64 pred_count = --unscheduled_pred_count.at(inst); + CHECK_GE(pred_count, 0); + if (pred_count == 0) { + ready_list.push_back(inst); + } + }; + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (HloInstruction* user : best->users()) { + update_pred_count(user); + } + for (HloInstruction* succ : best->control_successors()) { + update_pred_count(succ); + } + } + CHECK_EQ(schedule.size(), computation_.instructions().size()); + CHECK_EQ(scheduled_instructions_.size(), + computation_.instructions().size()); + + return schedule; + } + + const HloComputation& computation_; + const TuplePointsToAnalysis& points_to_analysis_; + const LogicalBuffer::SizeFunction& size_function_; + + // A map containing the LogicalBuffers that each instruction uses. + std::unordered_map> + buffer_uses_; + + // A map containing the count of unscheduled HLOs which using a particular + // LogicalBuffer. + std::unordered_map unscheduled_use_count_; + + // Set of instructions which have been scheduled. + std::unordered_set scheduled_instructions_; +}; + +int64 SumLogicalBufferSizes(const std::vector& buffers, + const LogicalBuffer::SizeFunction& size_function) { + int64 size = 0; + for (const LogicalBuffer* buffer : buffers) { + size += size_function(*buffer); + } + return size; +} + +StatusOr> RunDFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + // This ordering is based on DFS post-order, with a heuristic to decide which + // operand to visit first. The heuristic is based on 'extra_users', which is + // simply users-1 for each instruction. By subtracting 1, we're saying that + // instructions with no users or a single user don't count; instructions with + // lots of fan-out will be visited earlier. 
+ tensorflow::gtl::FlatMap extra_users; + tensorflow::gtl::FlatMap total_sizes; + for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; + total_sizes[hlo] = SumLogicalBufferSizes( + points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); + tensorflow::gtl::FlatSet unique_operands( + hlo->operands().begin(), hlo->operands().end()); + for (const HloInstruction* operand : unique_operands) { + extra_users[hlo] += extra_users[operand]; + total_sizes[hlo] += total_sizes[operand]; + } + } + CHECK_EQ(extra_users.size(), computation.instructions().size()); + CHECK_EQ(total_sizes.size(), computation.instructions().size()); + + // Construct a total order based on DFS post-order, visiting operands in + // decreasing cumulative extra user order, and next by cumulative size, with a + // tiebreaker by name for determinism. + std::vector sequence; + FunctionVisitor visitor([&sequence](HloInstruction* hlo) { + sequence.push_back(hlo); + return Status::OK(); + }); + TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + &visitor, [&extra_users, &total_sizes](const HloInstruction* a, + const HloInstruction* b) { + if (extra_users[a] != extra_users[b]) { + return extra_users[a] > extra_users[b]; + } + if (total_sizes[a] != total_sizes[b]) { + return total_sizes[a] > total_sizes[b]; + } + return a->name() < b->name(); + })); + CHECK_EQ(sequence.size(), computation.instructions().size()); + return sequence; +} + +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& 
computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + // We try both a list-scheduler based ordering and a DFS based ordering, and + // choose whichever returns a lower min-memory, not accounting for + // fragmentation. + // + // Note that this is just a heuristic. One obvious inaccuracy is that the + // memory required for sub-computations might be different when considered + // within the caller's context. But it's good enough for now. + TF_ASSIGN_OR_RETURN( + std::vector list_sequence, + ListScheduler::Run(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 list_memory, + MinimumMemoryForComputation(computation, list_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + + TF_ASSIGN_OR_RETURN( + std::vector dfs_sequence, + RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 dfs_memory, + MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, + size_function)); + VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + + if (list_memory <= dfs_memory) { + VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + return list_sequence; + } else { + VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + return dfs_sequence; + } +} + +} // namespace + +StatusOr +CreateMemoryMinimizingSequence( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { + SequentialHloOrdering::HloModuleSequence sequence; + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(&module)); + for (const auto& computation : module.computations()) { + TF_ASSIGN_OR_RETURN(sequence[computation.get()], + CreateMemoryMinimizingSequence( + *computation, *points_to_analysis, size_function)); + } + return sequence; +} + +StatusOr> 
CreateMemoryMinimizingSequence( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(computation.parent())); + return CreateMemoryMinimizingSequence(computation, *points_to_analysis, + size_function); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h new file mode 100644 index 0000000000000000000000000000000000000000..ec92a56b962152b15981f868369683144aa7c76a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -0,0 +1,50 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// Returns the minimum memory required to compute the given module sequence, +// assuming no fragmentation. 
+StatusOr MinimumMemoryForSequence( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + +// Returns an HloModuleSequence which seeks to minimize the memory required for +// the computation. size_function is the function returning the number of bytes +// required for a LogicalBuffer. +StatusOr +CreateMemoryMinimizingSequence( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function); + +// Overload of above that computes the sequence for a single computation. +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d09d22ee40638c5beed3f4eaf3723be0f6b6bf96 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + 
HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, + MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 867ebc7f61aab1483622d1560d951c053e95f135..e3d287d4c91708577b712261842b6ae231fb188b 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -75,7 +75,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant}, callee1)); auto y = builder.AddInstruction( @@ -89,12 +89,14 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", false, false, nullptr); + "before unification", + module->config().debug_options()); } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", false, false, nullptr); + "after unification", + module->config().debug_options()); } EXPECT_EQ(2, module->computations().size()); EXPECT_EQ(x->to_apply(), y->to_apply()); @@ -110,9 +112,9 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); + HloInstruction::CreateConstant(Literal::CreateR0(3))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1)); auto y = builder.AddInstruction( @@ -126,12 +128,14 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", false, false, 
nullptr); + "before unification", + module->config().debug_options()); } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", false, false, nullptr); + "after unification", + module->config().debug_options()); } EXPECT_EQ(2, module->computations().size()); EXPECT_EQ(x->to_apply(), y->to_apply()); @@ -164,12 +168,14 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "before unification", false, false, nullptr); + "before unification", + module->config().debug_options()); } EXPECT_FALSE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), - "after unification", false, false, nullptr); + "after unification", + module->config().debug_options()); } EXPECT_EQ(3, module->computations().size()); EXPECT_NE(x->to_apply(), y->to_apply()); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 6707b02c5c57262b0154ae6b23fdd61a198a8d70..76177462aa4959261483045296d2388acabe46a5 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -171,8 +171,7 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, break; case HloOpcode::kConstant: if (ShapeUtil::IsScalar(instruction->shape())) { - attrs["value"].set_s( - LiteralUtil::GetAsString(instruction->literal(), {})); + attrs["value"].set_s(instruction->literal().GetAsString({})); } break; case HloOpcode::kCustomCall: diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 
c2718ea8003c9d2a8e3d65773b439aae915a30d0..8e9d93e367e51cb69f0a38ae7aa8d9539e78ad8a 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -91,7 +91,7 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { auto builder = HloComputation::Builder("Const"); HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); + HloInstruction::CreateConstant(Literal::CreateR0(123))); OpMetadata metadata; metadata.set_op_name("x"); metadata.set_op_type("y"); diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5046a712e785d23d557e7567349fd94ca41e90f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -0,0 +1,313 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_value.h" + +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +const Shape& HloLocation::shape() const { + return ShapeUtil::GetSubshape(instruction->shape(), index); +} + +string HloLocation::ToString() const { + string index_str = + ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; + return StrCat(instruction->name(), index_str); +} + +std::ostream& operator<<(std::ostream& out, const HloLocation& location) { + out << location.ToString(); + return out; +} + +string HloUse::ToString() const { + string index_str = + ShapeUtil::IsTuple(instruction->operand(operand_number)->shape()) + ? (" " + operand_index.ToString()) + : ""; + return StrCat(instruction->name(), ", operand ", operand_number, index_str); +} + +std::ostream& operator<<(std::ostream& out, const HloUse& use) { + out << use.ToString(); + return out; +} + +HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, + const ShapeIndex& index, bool is_phi) + : id_(id), is_phi_(is_phi) { + // The defining location is always the first element in the locations_ vector. 
+ AddLocation(instruction, index); +} + +bool HloValue::operator==(const HloValue& other) const { + bool equal = defining_instruction() == other.defining_instruction() && + defining_index() == other.defining_index(); + // If the values are equal they must both be phi (or non phi). + CHECK(!(equal && is_phi() != other.is_phi())); + return equal; +} + +bool HloValue::operator!=(const HloValue& other) const { + return !(*this == other); +} + +string HloValue::ToShortString() const { + string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) + ? defining_index().ToString() + : ""; + return StrCat(is_phi_ ? "PHI " : "", defining_instruction()->name(), + index_str); +} + +string HloValue::ToString(int indent) const { + string indentation(indent, ' '); + string out = StrCat(indentation, ToShortString(), ", locations:\n"); + for (const HloLocation& location : locations()) { + StrAppend(&out, indentation, " ", location.ToString(), "\n"); + } + StrAppend(&out, indentation, " uses:\n"); + for (const HloUse& use : uses()) { + StrAppend(&out, indentation, " ", use.ToString(), "\n"); + } + return out; +} + +namespace { + +// Returns true if the instruction 'user' may use the value at the given +// ShapeIndex in the given operand. Generally, instructions which pass through +// values transparently without reading the value are not considered to use the +// value. +bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, + const HloInstruction* user) { + switch (user->opcode()) { + case HloOpcode::kGetTupleElement: + case HloOpcode::kCopy: + // These instructions only access the top-level values of their + // operand. Non-top-level (nested) values are passed through + // transparently. 
+ CHECK_EQ(operand_number, 0); + return index.empty(); + case HloOpcode::kSelect: + // Select does not use any nested elements of its selected-from operands + // (operand 1 and 2) + CHECK_GE(operand_number, 0); + CHECK_LE(operand_number, 2); + return operand_number == 0 || index.empty(); + + case HloOpcode::kCall: + case HloOpcode::kTuple: + // These instructions always pass through their operands transparently. + return false; + + case HloOpcode::kWhile: + // Though the while instruction passes through its operands, we return + // true because in SSA form there may be a Phi at the parameter of the + // while which is considered a use of its incoming value because the Phi + // input values are not passed through into the body computation. Because + // this function is used in both SSA and non-SSA forms of the analysis, + // conservatively return true. + return true; + + default: + return true; + } +} + +} // namespace + +void HloValue::AddLocation(HloInstruction* instruction, + const ShapeIndex& index) { + HloLocation new_location{instruction, index}; + + // The new location must not already exist in locations_. + for (const HloLocation& location : locations_) { + DCHECK_NE(location, new_location); + } + // The shape of the new location must match existing locations. + if (!locations_.empty()) { + CHECK( + ShapeUtil::Compatible(locations_.front().shape(), new_location.shape())) + << "front: " << locations_.front() << " new: " << new_location; + } + + locations_.push_back(std::move(new_location)); + + // Update uses. + for (HloInstruction* user : instruction->users()) { + for (int64 operand_number : user->OperandIndices(instruction)) { + if (MayUseOperandValue(operand_number, index, user)) { + HloUse new_use{user, operand_number, index}; + + // The new use must not already exist in uses_. + for (const HloUse& use : uses_) { + DCHECK_NE(use, new_use); + } + + uses_.push_back(std::move(new_use)); + } + } + } + + // Update liveout status of this HloValue. 
+ const HloModule& module = *instruction->parent()->parent(); + if (instruction == module.entry_computation()->root_instruction()) { + live_out_of_module_ = true; + } + + if (instruction == instruction->parent()->root_instruction()) { + live_out_of_computation_ = true; + } +} + +void HloValue::RemoveLocation(HloInstruction* instruction, + const ShapeIndex& index) { + // The defining location cannot be removed. + CHECK(!(instruction == defining_instruction() && index == defining_index())); + + int64 size_before = locations_.size(); + locations_.erase( + std::remove_if(locations_.begin(), locations_.end(), + [instruction, &index](const HloLocation& location) { + return location.instruction == instruction && + location.index == index; + }), + locations_.end()); + // Only a single location should have been removed. + CHECK_EQ(locations_.size(), size_before - 1); + + // Update uses which referred to this location. + uses_.erase(std::remove_if(uses_.begin(), uses_.end(), + [instruction, &index](const HloUse& use) { + return use.instruction->operand( + use.operand_number) == instruction && + use.operand_index == index; + }), + uses_.end()); + + // Returns whether this value is contained in the given instruction's output. + auto is_contained_in = [this](const HloInstruction* instruction) { + for (const HloLocation& location : locations()) { + if (location.instruction == instruction) { + return true; + } + } + return false; + }; + + const HloModule& module = *instruction->parent()->parent(); + if (instruction == module.entry_computation()->root_instruction()) { + // Value has been removed from a location in the entry root instruction. + live_out_of_module_ = + is_contained_in(module.entry_computation()->root_instruction()); + } + if (instruction == defining_instruction()->parent()->root_instruction()) { + // Value has been removed from the root of the computation the value has + // been defined in. 
+ live_out_of_computation_ = + is_contained_in(defining_instruction()->parent()->root_instruction()); + } +} + +std::ostream& operator<<(std::ostream& out, const HloValue& value) { + out << value.ToShortString(); + return out; +} + +void HloValueSet::SortAndUniquifyValues() { + std::sort(value_ids_.begin(), value_ids_.end()); + value_ids_.erase(std::unique(value_ids_.begin(), value_ids_.end()), + value_ids_.end()); +} + +string HloValueSet::ToString() const { + return StrCat("HloValueSet: ", tensorflow::str_util::Join(value_ids_, ", ")); +} + +/*static */ +HloValueSet HloValueSet::Union( + tensorflow::gtl::ArraySlice inputs) { + HloValueSet union_set; + for (const HloValueSet* input : inputs) { + for (HloValue::Id value_id : input->value_ids()) { + union_set.value_ids_.push_back(value_id); + } + } + union_set.SortAndUniquifyValues(); + return union_set; +} + +std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { + out << value_set.ToString(); + return out; +} + +InstructionValueSet InstructionValueSet::Union( + tensorflow::gtl::ArraySlice inputs) { + CHECK_GT(inputs.size(), 0); + for (int i = 1; i < inputs.size(); ++i) { + CHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); + } + InstructionValueSet union_set(inputs[0]->shape()); + union_set.ForEachMutableElement( + [&inputs](const ShapeIndex& index, HloValueSet* value_set) { + std::vector input_sets; + for (const InstructionValueSet* input : inputs) { + input_sets.push_back(&input->element(index)); + } + *value_set = HloValueSet::Union(input_sets); + }); + return union_set; +} + +std::ostream& operator<<(std::ostream& out, + const InstructionValueSet& instruction_value_set) { + out << instruction_value_set.ToString(); + return out; +} + +string InstructionValueSet::ToString() const { + string out = + StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); + ForEachElement([this, &out](const ShapeIndex& index, + const HloValueSet& value_set) { + 
StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); + }); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h new file mode 100644 index 0000000000000000000000000000000000000000..b0caf24a2184c14dbc6d1143ccf74d1390c267b4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -0,0 +1,247 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Abstraction which identifies a specific point in the XLA graph. An +// HloLocation specifies a ShapeIndex within the output of a specific +// instruction. +struct HloLocation { + HloInstruction* instruction; + ShapeIndex index; + + // Returns the shape at this location. 
+ const Shape& shape() const; + + string ToString() const; + + bool operator==(const HloLocation& other) const { + return instruction == other.instruction && index == other.index; + } + bool operator!=(const HloLocation& other) const { return !(*this == other); } +}; + +std::ostream& operator<<(std::ostream& out, const HloLocation& location); + +// Defines a single use of an HLO value. +struct HloUse { + // Instruction at which the value is used. + HloInstruction* instruction; + + // The operand number in which the value appears. + int64 operand_number; + + // The shape index within the operand in which the value appears. + ShapeIndex operand_index; + + string ToString() const; + + bool operator==(const HloUse& other) const { + return instruction == other.instruction && + operand_number == other.operand_number && + operand_index == other.operand_index; + } + + bool operator!=(const HloUse& other) const { return !(*this == other); } +}; + +std::ostream& operator<<(std::ostream& out, const HloUse& use); + +// Class describing a value used by the dataflow analysis. XLA arrays are +// trivially a single HloValue. Tuples are made up of more than one HloValue: an +// HloValue for the pointer vector, and an HloValue for each child element. +// +// Every HloValue is defined by a particular instruction and most instructions +// define only a single HloValue. Instructions which define a single HloValue +// include array-shaped instructions such as Add but also include Tuple-shaped +// instructions such as Tuple. The Tuple instruction defines a single HloValue +// which is a vector of pointers to the values containing the Tuple +// instruction's operands. Though the result of the Tuple instruction includes +// multiple values only the top-level HloValue (the vector of pointers) is +// defined by the Tuple instruction. The values containing the tuple elements +// are defined by earlier instructions, usually the operands of the Tuple +// instruction. 
+// +// Instructions which construct both the tuple *and* the tuple elements define +// more than one HloValue. This includes (at least) tuple-shaped Constant, +// Parameter, Infeed and While instructions. These tuple-shaped instructions do +// not assemble a tuple from existing HloValues like the Tuple instruction does, +// but rather define all the HloValues in the tuple. +class HloValue { + public: + using Id = int64; + + // Construct an HloValue defined by 'instruction' at shape index 'index'. If + // is_phi is true, then this value is a phi value, for example, at the + // parameter of a while body computation. Phi values are only used in the SSA + // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). + HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, + bool is_phi = false); + + // Return a unique identifier for this HloValue. This value is used for stable + // sorting and iteration + Id id() const { return id_; } + + // Returns whether this value is a phi value. + bool is_phi() const { return is_phi_; } + + // Return the location where this value is defined. + const HloLocation& defining_location() const { return locations_[0]; } + + // Return the instruction which defines this HloValue. + HloInstruction* defining_instruction() const { + return defining_location().instruction; + } + + // Return the shape index at which this HloValue is defined in the output of + // its defining instruction. + const ShapeIndex& defining_index() const { return defining_location().index; } + + // Return the shape of this HloValue. + const Shape& shape() const { return defining_location().shape(); } + + // Add or remove a location at which the HloValue appears. The definition + // location can not be removed. The uses of the HloValue are updated. 
+ void AddLocation(HloInstruction* instruction, const ShapeIndex& index); + void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index); + + // Return all locations of the HloValue in the module. + const std::vector& locations() const { return locations_; } + + // Return all uses of the HloValue. + const std::vector& uses() const { return uses_; } + + // Get whether this HloValue is live out of the module. + bool live_out_of_module() const { return live_out_of_module_; } + + // Get whether this HloValue is live out of the computation it is defined in. + bool live_out_of_computation() const { return live_out_of_computation_; } + + bool operator==(const HloValue& other) const; + bool operator!=(const HloValue& other) const; + + // Return a single-line string representation of the value. + string ToShortString() const; + + string ToString(int indent = 0) const; + + private: + // Unique identifier for this HloValue. Used for stable sorting and iteration. + const Id id_; + + // Whether this instruction is a phi value. + const bool is_phi_; + + // The set of locations of this HloValue. The first element is always the + // location of the definition. + std::vector locations_; + + // The set of uses of this HloValue. + std::vector uses_; + + // Whether this value is live out of the HLO module. + bool live_out_of_module_ = false; + + // Whether this value is live out of its computation. + bool live_out_of_computation_ = false; +}; + +std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); + +// A class representing the possible set of HloValues at a particular point +// (shape index in the output of an instruction) in the XLA graph. This set +// contains the set of reaching HloValue definitions. For a simple array-shaped +// instruction like Add, the HloValueSet of the top-level of the instruction's +// output trivially contains only the HloValue defined by the instruction. 
For +// instructions which have non-trivial dataflow such as Tuple or Select, the +// HloValueSets of the instruction's output contains one or more HloValues +// defined by the instruction's operands or defined further up in the XLA graph. +class HloValueSet { + public: + HloValueSet() = default; + + explicit HloValueSet(tensorflow::gtl::ArraySlice value_ids) + : value_ids_(value_ids.begin(), value_ids.end()) { + SortAndUniquifyValues(); + } + + // Return the union of the given HloValueSets. + static HloValueSet Union( + tensorflow::gtl::ArraySlice inputs); + + // Return the vector of the IDs of all HloValues in the set. Values in the + // vector are unique and sorted. + const std::vector& value_ids() const { return value_ids_; } + + // Return the unique HLO value in the set. CHECKs if the set does not contain + // exactly one value. + HloValue::Id GetUniqueValueId() const { + CHECK_EQ(value_ids().size(), 1); + return value_ids()[0]; + } + + bool operator==(const HloValueSet& other) const { + return value_ids() == other.value_ids(); + } + bool operator!=(const HloValueSet& other) const { return !(*this == other); } + + string ToString() const; + + private: + // Sorts value_ and removes duplicates. This should be called after adding any + // elements to values_. + void SortAndUniquifyValues(); + + // HloValues sorted by HloValue::Id. + std::vector value_ids_; +}; + +std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value); + +// A class collecting the HloValues which might be contained in the output of +// an HLO instruction. For array-shaped instructions, an InstructionValueSet +// trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets +// hold multiple HloValueSets. +class InstructionValueSet : public ShapeTree { + public: + InstructionValueSet(const Shape& shape) : ShapeTree(shape) {} + + // Return the union of the given InstructionValueSets. 
+ static InstructionValueSet Union( + tensorflow::gtl::ArraySlice inputs); + + string ToString() const; +}; + +std::ostream& operator<<(std::ostream& out, + const InstructionValueSet& instruction_value_set); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index de6081e57e7f27a07b314692c6935ecf3e3c54a9..01fba49bc567900418f9e4622351373abe7b2e18 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -14,10 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { StatusOr HloVerifier::Run(HloModule* module) { + tensorflow::gtl::FlatMap instructions; + for (auto& computation : module->computations()) { for (const auto& instruction : computation->instructions()) { TF_RET_CHECK(instruction->parent() == computation.get()); @@ -30,6 +33,16 @@ StatusOr HloVerifier::Run(HloModule* module) { << " computation: " << computation.get(); } } + + auto previous = instructions.find(instruction->name()); + TF_RET_CHECK(previous == instructions.end()) + << "HLO has name that is not unique within module:\n" + << instruction->ToString() + << " in computation: " << computation->name() + << "\nPrevious HLO with same name:\n" + << previous->second->ToString() + << " in computation: " << previous->second->parent()->name(); + instructions[instruction->name()] = instruction.get(); } } diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 2887a8a0a097c9dcb3d490f0845547f104aa1bdf..84bfbb30c30d84a6a233a60fb420b43c3fe3454c 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -51,10 +51,10 
@@ TEST_F(InlinerTest, MapMax) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapMaxFunction"); - auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -70,7 +70,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); + auto expected = Literal::CreateR1({4, 3, 3, 4}); LiteralTestUtil::ExpectEqual(*result, *expected); } @@ -83,12 +83,12 @@ TEST_F(InlinerTest, MapConstant) { HloInstruction::CreateParameter(0, r0f32, "x")); (void)param1; const2_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); auto const2_f32 = const2_builder.Build(); auto builder = HloComputation::Builder("MapConstFunction"); auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); @@ -104,7 +104,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. 
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); + auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); LiteralTestUtil::ExpectEqual(*result, *expected); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 721640cdbd8133f621f65a2505cdf3b84590e740..2a66943f9b2f0a68b4d56ca0961f5c87ac468073 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { - /*static*/ bool InstructionFusion::IsExpensive( const HloInstruction& instruction) { switch (instruction.opcode()) { @@ -43,6 +42,7 @@ namespace xla { case HloOpcode::kConstant: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: @@ -64,6 +64,7 @@ namespace xla { case HloOpcode::kNegate: case HloOpcode::kOutfeed: case HloOpcode::kPad: + case HloOpcode::kReducePrecision: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSelect: @@ -75,6 +76,7 @@ namespace xla { return false; // Expensive instructions. + case HloOpcode::kBatchNormTraining: case HloOpcode::kCall: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: @@ -113,8 +115,98 @@ bool FusionWouldDuplicate(const HloInstruction& producer, const HloInstruction& consumer) { return !(producer.users().size() == 1 && consumer.IsUserOf(&producer)); } + +// An "effectively unary" operation is one that has one "large" +// input with the others being negligible in terms of memory usage. +// We use "has a smaller true rank than the output" as a heuristic +// for "negligible" memory usage. 
+bool EffectivelyUnary(HloInstruction* hlo) { + int64 output_rank = 0; + ShapeUtil::ForEachSubshape( + hlo->shape(), + [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape)); + } + }); + return std::count_if(hlo->operands().begin(), hlo->operands().end(), + [output_rank](HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kBroadcast) { + return false; + } + if (operand->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(operand->shape())) { + return false; + } + return ShapeUtil::TrueRank(operand->shape()) >= + output_rank; + }) <= 1; +} } // namespace +bool InstructionFusion::CanFuseOnAllPaths( + const HloReachabilityMap& reachability_map, HloInstruction* producer, + HloInstruction* consumer, DoNotFuseSet* do_not_fuse) { + auto could_fuse_on_all_paths = [&] { + // First check to see if we have already marked this producer as infeasible + // to fuse into consumer. + if (do_not_fuse->count(producer) > 0) { + return false; + } + // Make sure it is possible for producer and consumer to exist in a fusion + // node. + if (!producer->IsFusable() || !consumer->IsFusable()) { + return false; + } + // We do an upward walk of the graph from consumer towards all paths which + // lead to producer to find any unfusable paths. + for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { + auto* consumer_operand = consumer->mutable_operand(i); + if (consumer_operand == producer) { + // This is the base case: our upward crawl ends but we need to make sure + // that fusion from consumer can happen. + if (!ShouldFuse(consumer, i)) { + return false; + } + } else if (reachability_map.IsReachable(producer, consumer_operand)) { + // The reachability map told us that consumer_operand is a node on the + // path to producer. We need to further investigate from + // consumer_operand. 
+ + // First check if we have already ruled out fusing producer into + // consumer_operand. + if (do_not_fuse->count(consumer_operand) > 0) { + return false; + } + // Make sure it is possible for consumer_operand to exist in a fusion + // node. + if (!consumer_operand->IsFusable()) { + return false; + } + // The producer is reachable from consumer_operand which means we need + // to be able to fuse consumer_operand into consumer in order for + // producer to be fusable into consumer on all paths. + if (!ShouldFuse(consumer, i)) { + return false; + } + // Perform the recursive step: make sure producer can be fused into + // consumer_operand on all paths. + if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand, + do_not_fuse)) { + return false; + } + } + } + return true; + }; + if (could_fuse_on_all_paths()) { + return true; + } + // We couldn't fuse on all paths, record this result. + do_not_fuse->insert(producer); + return false; +} + StatusOr InstructionFusion::Run(HloModule* module) { bool changed = false; for (auto& computation : module->computations()) { @@ -131,56 +223,42 @@ StatusOr InstructionFusion::Run(HloModule* module) { std::vector post_order(post_order_list.begin(), post_order_list.end()); - std::set all_consumers_fusable; - // Find which ops can be fused into all of their operands. We would rather - // not fuse an op into only some of its users, as that offers no benefit in - // terms of memory bandwidth, but forces us to keep more live values around. 
- for (auto* hlo : post_order) { - auto user_fusable_into_hlo = [this, &hlo](HloInstruction* consumer) { - if (!consumer->IsFusable()) { - return false; - } - for (int operand_number = 0; - operand_number < consumer->operands().size(); ++operand_number) { - if (consumer->operand(operand_number) == hlo) { - if (!ShouldFuse(consumer, operand_number)) { - return false; - } - } - } - return true; - }; - - // An "effectively unary" operation is one that has one "large" - // input with the others being negligible in terms of memory usage. - // We use "has a smaller true rank than the output" as a heuristic - // for "negligible" memory usage. - auto effectively_unary = [](HloInstruction* hlo) { - if (hlo->operands().size() == 1) { - return true; - } - auto output_rank = ShapeUtil::TrueRank(hlo->shape()); - return std::count_if( - hlo->operands().begin(), hlo->operands().end(), - [output_rank](HloInstruction* operand) { - return ((operand->opcode() != HloOpcode::kBroadcast) && - ShapeUtil::TrueRank(operand->shape()) >= - output_rank); - }) <= 1; - }; - - if (effectively_unary(hlo) || - std::all_of(hlo->users().begin(), hlo->users().end(), - user_fusable_into_hlo)) { - all_consumers_fusable.insert(hlo); - } - } - tensorflow::gtl::FlatMap post_order_index; for (size_t i = 0; i < post_order.size(); ++i) { InsertOrDie(&post_order_index, post_order[i], i); } + DoNotFuseSet do_not_fuse; + auto reachability = computation->ComputeReachability(); + + auto cheap_to_duplicate = [](HloInstruction* producer) { + if (producer->opcode() == HloOpcode::kBroadcast) { + return true; + } + if (producer->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(producer->shape())) { + return true; + } + if (EffectivelyUnary(producer)) { + return true; + } + return false; + }; + + for (HloInstruction* consumer : post_order) { + for (HloInstruction* producer : consumer->operands()) { + if (cheap_to_duplicate(producer)) { + continue; + } + if (CanFuseOnAllPaths(*reachability, producer, 
consumer, + &do_not_fuse)) { + CHECK_EQ(do_not_fuse.count(producer), 0); + } else { + CHECK_GT(do_not_fuse.count(producer), 0); + } + } + } + // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all // edges. When we fuse an edge, we create a copy of the producer inside the @@ -263,34 +341,36 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); - if (FusionWouldDuplicate(*operand, *instruction) && - (all_consumers_fusable.count(operand) == 0)) { + if (!operand->IsFusable()) { continue; } - - if (operand->IsFusable() && ShouldFuse(instruction, i)) { - HloInstruction* fusion_instruction = Fuse(operand, instruction); - - // Fusing an instruction into a fusion instruction can change the - // operand set of the fusion instruction. For simplicity just push the - // instruction to the top of the post_order and reconsider it for - // further fusion in the next iteration of the outer loop. - post_order.push_back(fusion_instruction); - InsertOrDie(&post_order_index, fusion_instruction, - post_order.size() - 1); - changed = true; - - if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting it's - // location to nullptr. - post_order[FindOrDie(post_order_index, operand)] = nullptr; - post_order_index.erase(operand); - - // Remove from computation. - TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); - } - break; + if (!ShouldFuse(instruction, i)) { + continue; + } + if (do_not_fuse.count(operand) > 0) { + continue; + } + HloInstruction* fusion_instruction = Fuse(operand, instruction); + + // Fusing an instruction into a fusion instruction can change the + // operand set of the fusion instruction. 
For simplicity just push the + // instruction to the top of the post_order and reconsider it for + // further fusion in the next iteration of the outer loop. + post_order.push_back(fusion_instruction); + InsertOrDie(&post_order_index, fusion_instruction, + post_order.size() - 1); + changed = true; + + if (operand->user_count() == 0) { + // Operand is now dead. Remove from post order by setting it's + // location to nullptr. + post_order[FindOrDie(post_order_index, operand)] = nullptr; + post_order_index.erase(operand); + + // Remove from computation. + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); } + break; } } } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index a9f3723f2dfcc1b3b697d34eb9510f5857a443f0..f6f37bb79b9fe1480db61b10b9810347960f9a72 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -72,6 +72,15 @@ class InstructionFusion : public HloPassInterface { private: HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); + // The set of producers whose consumers we cannot fuse into. + using DoNotFuseSet = std::unordered_set; + + // Whether or not we can fuse consumer into original_producer on all paths + // from the producer to the consumer where nodes are HLOs and edges are uses. + bool CanFuseOnAllPaths(const HloReachabilityMap& reachability_map, + HloInstruction* producer, HloInstruction* consumer, + DoNotFuseSet* do_not_fuse); + // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. 
std::function is_expensive_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index a2e6c2ae00bd65b1d3aeca49f26448d8a07670a8..b3e0007dcc2d43028b49cc48477a0a69153b13c8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -28,7 +28,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndOperandElementReusingConsumerNotFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* broadcast2 = @@ -49,7 +49,7 @@ TEST_F(InstructionFusionTest, NonCostlyProducerAndOperandElementReusingConsumerFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0)); HloInstruction* broadcast2 = @@ -70,7 +70,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* reshape2 = builder.AddInstruction( @@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { 
HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* transpose2 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index e9e199226a6db7a0547bda4b069e917f2a41295b..aafface0b9f3013d01c1f6d13ef1f4c7927b5a16 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -382,7 +382,11 @@ Status LayoutAssignment::AddMandatoryConstraints( // instruction. // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. - shape_with_layout = &instruction->shape(); + // TODO(b/62477016): When the infeed does not set padding anymore, the + // call to ShapeWithoutPadding can be removed. + Shape infeed_shape = ShapeUtil::ShapeWithoutPadding(instruction->shape()); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(infeed_shape, instruction.get())); } else if (instruction->opcode() == HloOpcode::kOutfeed) { // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. @@ -729,23 +733,18 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kReshape) { // Prefer the operand layout that makes the reshape an bitcast. If any // dimension bound is 1 in the operand shape, there may be several such - // layouts. So if 'output_layout' is a MajorToMinor layout, try if the + // layouts. So if 'output_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy // operations. 
const Shape& output_shape = instruction->shape(); Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), AsInt64Slice(output_layout.minor_to_major())); - const Shape& operand_shape = operand->shape(); - if (LayoutUtil::IsMonotonicWithDim0Major(output_layout)) { - Shape operand_shape_with_layout = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - operand_shape.element_type(), - AsInt64Slice(operand_shape.dimensions())); - if (ShapeUtil::ReshapeIsBitcast(operand_shape_with_layout, - output_shape_with_layout)) { - return MakeUnique(operand_shape_with_layout.layout()); - } + Shape operand_shape = operand->shape(); + *operand_shape.mutable_layout() = + LayoutUtil::GetDefaultLayoutForShape(operand_shape); + if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { + return MakeUnique(operand_shape.layout()); } auto aligned_operand_shape = ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); @@ -759,10 +758,14 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kTranspose) { // Pick the operand layout that makes the transpose a bitcast. 
- std::vector perm = - ComposePermutations(instruction->dimensions(), - AsInt64Slice(output_layout.minor_to_major())); - Layout operand_layout = LayoutUtil::MakeLayout(perm); + int64 rank = ShapeUtil::Rank(instruction->shape()); + std::vector new_minor_to_major(rank); + for (int64 i = 0; i < rank; ++i) { + int64 output_dim = output_layout.minor_to_major(i); + int64 operand_dim = instruction->dimensions(output_dim); + new_minor_to_major[i] = operand_dim; + } + Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); return MakeUnique(operand_layout); @@ -789,23 +792,18 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (user->opcode() == HloOpcode::kReshape) { // Prefer the user layout that makes the reshape an bitcast. If any // dimension bound is 1 in the user shape, there may be several such - // layouts. So if 'operand_layout' is a MajorToMinor layout, try if the + // layouts. So if 'operand_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy // operations. 
Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( operand->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), AsInt64Slice(operand_layout.minor_to_major())); - const Shape& output_shape = user->shape(); - if (LayoutUtil::IsMonotonicWithDim0Major(operand_layout)) { - Shape output_shape_with_layout = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - output_shape.element_type(), - AsInt64Slice(output_shape.dimensions())); - if (ShapeUtil::ReshapeIsBitcast(output_shape_with_layout, - operand_shape_with_layout)) { - return MakeUnique(output_shape_with_layout.layout()); - } + Shape output_shape = user->shape(); + *output_shape.mutable_layout() = + LayoutUtil::GetDefaultLayoutForShape(output_shape); + if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { + return MakeUnique(output_shape.layout()); } auto aligned_user_shape = ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); @@ -818,14 +816,16 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } if (user->opcode() == HloOpcode::kTranspose) { - // Pick the user layout that makes the reshape a bitcast. - // To become a bitcast, the layouts need to satisfy - // collapsing_order * output_layout = input_layout - // so output_layout = inverse(collapsing_order) * input_layout - std::vector perm = - Permute(InversePermutation(user->dimensions()), - AsInt64Slice(operand_layout.minor_to_major())); - Layout user_layout = LayoutUtil::MakeLayout(perm); + // Pick the user layout that makes the transpose a bitcast. 
+ int64 rank = ShapeUtil::Rank(user->shape()); + std::vector new_minor_to_major(rank); + auto inverse_dimensions = InversePermutation(user->dimensions()); + for (int64 i = 0; i < rank; ++i) { + int64 operand_dim = operand_layout.minor_to_major(i); + int64 user_dim = inverse_dimensions[operand_dim]; + new_minor_to_major[i] = user_dim; + } + Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); return MakeUnique(user_layout); } @@ -926,7 +926,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( ShapeUtil::IsArray(buffer->shape())) { TF_RETURN_IF_ERROR(constraints->SetBufferLayout( ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(), - *buffer)); + *buffer, /*mandatory=*/true)); } } } @@ -1346,8 +1346,7 @@ StatusOr LayoutAssignment::Run(HloModule* module) { if (VLOG_IS_ON(10)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), "before layout assignment", - /*show_addresses=*/false, - /*show_layouts=*/true); + module->config().debug_options()); } // Assign layouts to computations in an order such that a callee computation @@ -1373,8 +1372,7 @@ StatusOr LayoutAssignment::Run(HloModule* module) { if (VLOG_IS_ON(10)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), "after layout assignment", - /*show_addresses=*/false, - /*show_layouts=*/true); + module->config().debug_options()); } // All layouts are reset then reassigned by this pass. 
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 6d818cdea0c30701adf83f6265a6d7b554fb91cc..f69c043f32b4e688a543d277164eb91b364b51dc 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -230,7 +230,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateTuple({constant0, constant1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); @@ -264,7 +264,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // tuple and assigning the layouts of the copied arrays as needed. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto inner_tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); auto nested_tuple = builder.AddInstruction( @@ -552,6 +552,41 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { ElementsAre(1, 0)); } +// Test layout assignment of a transpose into a bitcast based on its operand. 
+TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape_with_layout = + ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1})); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(transpose)); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + AssignLayouts(module.get(), &computation_layout); + EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), + transpose->shape(), {2, 3, 0, 1})); +} +// Test layout assignment of a transpose into a bitcast based on its user. +TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7}); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(input_shape, constant, {})); + auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1})); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(transpose)); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + AssignLayouts(module.get(), &computation_layout); + EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), + transpose->shape(), {2, 3, 0, 1})); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 
682bf19807b4b5d4e8a66c6c5e2e01c80a026594..9c80fb3adbc99b2e5cd3efc20deaf602c5ebc526 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -28,17 +28,6 @@ limitations under the License. namespace xla { -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - - // GetTupleElement instructions only access the top-level buffer of their - // operand. - return (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()); -} - bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user, @@ -149,18 +138,22 @@ bool HasUniqueFusedUseOfOperandAt( // User and operand can share buffers iff both instructions emit the same shape // and layout, and 'user' meets one of the following qualifications: -// *) Is element-wise. Or... -// *) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. Or... -// *) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion -// instruction where the only use of 'operand' at 'index' in the set -// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... -// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0. +// +// (1) Is element-wise. Or... +// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' +// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root +// at operand 0. Or... +// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion +// instruction where the only use of 'operand' at 'index' in the set +// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... 
+// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index +// 0. +// +// (2) and (3) can only be determined if points-to analysis is available. bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { + const TuplePointsToAnalysis* points_to_analysis) { CHECK(user->IsUserOf(operand)) << "user: " << user->ToString() << " operand: " << operand->ToString(); Shape operand_subshape = @@ -170,7 +163,7 @@ bool CanShareOperandBufferWithUser( if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; } - if (user->opcode() == HloOpcode::kFusion) { + if (points_to_analysis != nullptr && user->opcode() == HloOpcode::kFusion) { if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { @@ -180,7 +173,7 @@ bool CanShareOperandBufferWithUser( // 'operand_index', and this singleton use is the fused root at operand // index 0. return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - points_to_analysis); + *points_to_analysis); } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -208,7 +201,7 @@ bool CanShareOperandBufferWithUser( // index 'other_add_operand_index'). 
return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, other_add_operand_index, - points_to_analysis); + *points_to_analysis); } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h index 0b01223db73d49ad3ee127dd9076e37f5fac8ec5..c7799e5ab5d0c0d0477c09fa7e6a36c67312a72b 100644 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ b/tensorflow/compiler/xla/service/liveness_util.h @@ -34,21 +34,16 @@ bool DoesNotUseOperandBuffer(const HloInstruction* operand, const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis); -// Overload which does not require points-to analysis. The result is more -// conservative (returns false more often). -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user); - // Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). -// Returns false otherwise. +// 'operand' (at 'operand_index'). Returns false otherwise. Optionally takes a +// points-to analysis argument. Without the analysis, the result is more +// conservative (returns false more often). // // REQUIRES: 'operand' is an operand of 'user'. 
bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis); + const TuplePointsToAnalysis* points_to_analysis = nullptr); } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index bad4be149a68bdc07a1f7e4ac0668728d10d152e..6a4fde87614750d21cf9572e7f447bba924379c4 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -85,9 +85,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -122,10 +122,10 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { BuildModuleAndRunAnalysis(builder.Build()); - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {}, + points_to_analysis_.get())); + EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, log, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { @@ -143,9 +143,9 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { BuildModuleAndRunAnalysis(builder.Build()); EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - 
*points_to_analysis_)); + points_to_analysis_.get())); EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { @@ -161,10 +161,10 @@ TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { BuildModuleAndRunAnalysis(builder.Build()); - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {}, + points_to_analysis_.get())); + EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, copy, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { @@ -180,9 +180,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -197,9 +197,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // The fusion instruction can share with tuple element 1. 
EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { @@ -221,12 +221,12 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { // The DynamicUpdateSlice instruction can share with the data operand, but not // with update or starts. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, dus, {}, + points_to_analysis_.get())); + EXPECT_FALSE(CanShareOperandBufferWithUser(update, {}, dus, {}, + points_to_analysis_.get())); + EXPECT_FALSE(CanShareOperandBufferWithUser(starts, {}, dus, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { @@ -234,15 +234,15 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto dot = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -256,7 +256,7 @@ 
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { // Output fused dot add should be able to share buffer with 'add_operand'. EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { @@ -264,9 +264,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); @@ -274,7 +274,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -292,7 +292,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { // Output fused transpose-dot-add should be share buffer with 'add_operand'. 
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { @@ -300,7 +300,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -308,7 +308,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -320,7 +320,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { // Output fused operand->reverse->add cannot alias operand buffer 'operand'. EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { @@ -360,8 +360,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { RunAnalysis(); // The While instruction can share with the data operand. 
- EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, whil, {}, + points_to_analysis_.get())); } } // namespace diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 12b2762f0ed7eb9acce8a60d4501ab6ce53c3b57..61945bd128e68b59bd0a1156882c5b29d6be2a27 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -29,7 +29,6 @@ cc_library( ":ir_array", ":llvm_util", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", @@ -47,7 +46,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -93,6 +91,7 @@ cc_library( deps = [ ":ir_array", ":llvm_loop", + ":ops", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 02710ff57f6f75fe6aa1c32670cc7196ae4c402f..1f6932bcc3fb76adb874b963ecf5fb1b16d8a9f4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include #include "external/llvm/include/llvm/IR/MDBuilder.h" -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/types.h" @@ -51,28 +50,37 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, buffer_slice = *slices.begin(); } - llvm::MDNode*& alias_scope_md = alias_scope_metadata_[buffer_slice]; - if (alias_scope_md == nullptr) { - alias_scope_md = - GetAliasScopeMetadataForBuffer(buffer_slice, GetAliasDomain()); + if (module_.config().debug_options().xla_llvm_enable_alias_scope_metadata()) { + llvm::MDNode*& alias_scope_md = alias_scope_metadata_[buffer_slice]; + if (alias_scope_md == nullptr) { + alias_scope_md = + GetAliasScopeMetadataForBuffer(buffer_slice, GetAliasDomain()); + } + array->AddAliasScopeMetadata(alias_scope_md); } - array->AddAliasScopeMetadata(alias_scope_md); - llvm::MDNode*& noalias_md = noalias_metadata_[buffer_slice]; - if (noalias_md == nullptr) { - noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(), - assignment_, hlo); + if (module_.config().debug_options().xla_llvm_enable_noalias_metadata()) { + llvm::MDNode*& noalias_md = noalias_metadata_[buffer_slice]; + if (noalias_md == nullptr) { + noalias_md = GetNoaliasMetadataForBuffer(buffer_slice, GetAliasDomain(), + assignment_, hlo); + } + array->AddNoaliasMetadata(noalias_md); } - array->AddNoaliasMetadata(noalias_md); - // Parameters of the entry computation are never stored to, loading from a - // parameter pointer should always return the same result within a loop. 
- if (hlo.opcode() == HloOpcode::kParameter) { - const std::vector& parameter_instructions = - module_.entry_computation()->parameter_instructions(); - if (std::find(parameter_instructions.begin(), parameter_instructions.end(), - &hlo) != parameter_instructions.end()) { - array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + if (module_.config() + .debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // Parameters of the entry computation are never stored to, loading from a + // parameter pointer should always return the same result within a loop. + if (hlo.opcode() == HloOpcode::kParameter) { + const std::vector& parameter_instructions = + module_.entry_computation()->parameter_instructions(); + if (std::find(parameter_instructions.begin(), + parameter_instructions.end(), + &hlo) != parameter_instructions.end()) { + array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + } } } } @@ -87,12 +95,6 @@ llvm::MDNode* AliasAnalysis::GetAliasDomain() { llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain) { - legacy_flags::AliasAnalysisFlags* flags = - legacy_flags::GetAliasAnalysisFlags(); - if (!flags->xla_emit_alias_scope) { - return nullptr; - } - // While we could synthesize an alias.scope, doing so is not more profitable // than LLVM's default behavior. if (buffer_slice.allocation() == kParameterAllocation) { @@ -109,12 +111,6 @@ llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain, const BufferAssignment& assignment, const HloInstruction& hlo) { - legacy_flags::AliasAnalysisFlags* flags = - legacy_flags::GetAliasAnalysisFlags(); - if (!flags->xla_emit_alias_scope) { - return nullptr; - } - // We want to construct a list of buffers which: // // 1. Do not alias the given buffer. 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index b259d348708c227a3e580fd352422e457284129d..26e73a6ec390c5823c2a0315480a427ea0a7b373 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -128,6 +128,27 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } +Status FusedIrEmitter::HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) { + std::vector operand_elemental_ir_types; + for (HloInstruction* operand : operands) { + operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( + operand->shape().element_type(), ir_builder_)); + } + generators_[tuple] = + [=](const IrArray::Index& index) -> StatusOr { + llvm::Value* ret = llvm::UndefValue::get(llvm::StructType::get( + ir_builder_->getContext(), operand_elemental_ir_types)); + for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value * val_i, generators_[operands[i]](index)); + ret = ir_builder_->CreateInsertValue(ret, val_i, i); + } + return ret; + }; + return Status::OK(); +} + Status FusedIrEmitter::FinishVisit(HloInstruction* root) { fused_root_ = root; return tensorflow::Status::OK(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 79007b7099a32973cada7a9986ff95c5e4aabec6..1cd8d1194686236dd11f71c56d668708ad113f03 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -54,6 +54,11 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { Status HandleParameter(HloInstruction* parameter) override; + // Emits the ir value for each element in the tuple. 
+ Status HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) override; + Status FinishVisit(HloInstruction* root) override; // Returns the generator function for the root of the fused computation. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e401305ae7342a9db09499c9b3846f5a0a705fa7..dd13a39232badf659b6d1ba4d89e14c0632bc885 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -85,7 +85,7 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) ++depth; } - if (ShapeUtil::Rank(*shape_) == 0) { + if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); @@ -153,6 +153,28 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( return Index(source_multidim_index); } +IrArray::Index IrArray::Index::SourceIndexOfSlice( + const Shape& shape, tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice strides, + llvm::IRBuilder<>* builder) const { + Index source_index(multidim_.size()); + for (int i = 0; i < multidim_.size(); ++i) { + int64 stride = strides[i]; + auto type = multidim_[i]->getType(); + + if (stride != 1) { + source_index[i] = builder->CreateAdd( + builder->CreateMul(multidim_[i], + llvm::ConstantInt::get(type, stride)), + llvm::ConstantInt::get(type, starts[i])); + } else { + source_index[i] = builder->CreateAdd( + multidim_[i], llvm::ConstantInt::get(type, starts[i])); + } + } + return source_index; +} + IrArray::Index IrArray::Index::SourceIndexOfTranspose( const Shape& shape, const Shape& operand_shape, tensorflow::gtl::ArraySlice dimension_mapping, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 
91cb3a679fd67fffb29f8a935cc3c65e9442136b..e72f6518727ca35224b3f8649d00fc798336bf18 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -115,6 +115,16 @@ class IrArray { Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape, llvm::IRBuilder<>* builder) const; + // Returns the index into the source operand from which a slice operation + // selects a value to be placed into index "this". The slice is described + // by starting indices `starts` and stride values `strides`. + // + // Precondition: "this" is an index into a slice whose shape is `shape`. + Index SourceIndexOfSlice(const Shape& shape, + tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice strides, + llvm::IRBuilder<>* builder) const; + // Given that "this" is the target index of a transpose from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. Index SourceIndexOfTranspose( diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ff2f4cd693ca76c0e4d20522f50a302fb3ae2c40..e348511c6269e15e19ca64af1c90458cf4d4ba7d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "external/llvm/include/llvm/IR/Operator.h" #include "external/llvm/include/llvm/Target/TargetOptions.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -163,36 +162,36 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, llvm::Constant* value; switch (shape.element_type()) { case PRED: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case U8: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case S32: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case U32: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case S64: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case U64: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case F32: - value = llvm::ConstantFP::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantFP::get(ir_element_type, + literal.Get(*multi_index)); break; case F64: - value = llvm::ConstantFP::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = 
llvm::ConstantFP::get(ir_element_type, + literal.Get(*multi_index)); break; default: LOG(FATAL) << "unsupported type " << shape.element_type(); @@ -357,31 +356,9 @@ void EmitLogging(const char* tag, llvm::Value* value, void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, bool is_pointer_to) { - legacy_flags::LlvmUtilFlags* flags = legacy_flags::GetLlvmUtilFlags(); - if (!flags->xla_emit_tbaa) { - return; - } - - llvm::MDBuilder metadata_builder(instruction->getContext()); - llvm::MDNode* root = metadata_builder.createTBAARoot("XLA TBAA"); - string type_name; - if (is_pointer_to) { - type_name += "pointer-to "; - } - // Scalars do not have layout which makes it permissible to omit an explicit - // layout. To make sure that equivalent scalar shapes have the same TBAA, - // remove the (meaningless) explicit layout if one is present. - if (ShapeUtil::Rank(shape) == 0) { - LayoutUtil::ClearLayout(&shape); - } else { - CHECK(shape.has_layout()); - } - type_name += shape.ShortDebugString(); - llvm::MDNode* tbaa_node = - metadata_builder.createTBAANode(llvm_ir::AsStringRef(type_name), root); - instruction->setMetadata(llvm::LLVMContext::MD_tbaa, - metadata_builder.createTBAAStructTagNode( - tbaa_node, tbaa_node, /*Offset=*/0)); + // TODO(b/62903316): TBAA metadata causes LLVM to miscompile generated code, + // most likely because the generated metadata is incorrect. Disable TBAA + // metadata while we resolve this. } void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 9a128b2aa6f2d5e5650624f103c573e671335f7b..8839ec582df844f46f060e26917f15aa297cba3d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -51,8 +52,41 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, shape_(target_array.GetShape()), ir_builder_(ir_builder) {} +LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, + tensorflow::gtl::ArraySlice target_arrays, + llvm::IRBuilder<>* ir_builder) + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) + -> ::tensorflow::Status { + // Convert target_element_generator to a BodyEmitter. + TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + target_element_generator(array_index)); + if (target_arrays.size() == 1) { + target_arrays[0].EmitWriteArrayElement(array_index, target_element, + ir_builder); + return tensorflow::Status::OK(); + } + + for (int64 i = 0; i < target_arrays.size(); ++i) { + target_arrays[i].EmitWriteArrayElement( + array_index, ir_builder_->CreateExtractValue(target_element, i), + ir_builder); + } + return tensorflow::Status::OK(); + }), + ir_builder_(ir_builder) { + if (target_arrays.size() > 1) { + // The sanity check for multiple outputs. + shape_ = target_arrays[0].GetShape(); + for (int64 i = 1; i < target_arrays.size(); ++i) { + const Shape& element_shape = target_arrays[i].GetShape(); + CHECK(ShapeUtil::SameDimensions(shape_, element_shape)); + } + } else { + shape_ = target_arrays[0].GetShape(); + } +} + IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock() { - CHECK(!ShapeUtil::IsTuple(shape_)); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. 
exit_bb_ = nullptr; diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 08171e9e9de294339359f86059f89dcf4939ddea..ab6b702c441e04f2c7988a3dcb9880a86ff95355 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -47,6 +47,10 @@ class LoopEmitter { // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder); + // Same as previous method except emits multiple targets in an array. + LoopEmitter(const ElementGenerator& target_element_generator, + tensorflow::gtl::ArraySlice target_arrays, + llvm::IRBuilder<>* ir_builder); LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; virtual ~LoopEmitter() = default; diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 131c2ee87b0e78a4f7e315bfbb2b2793c0a91fa1..25588a6fb8adb8c05a773c56e24c7b252e5e26f0 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -46,13 +47,6 @@ namespace se = ::perftools::gputools; namespace xla { -/* static */ StatusOr> LocalService::NewService( - perftools::gputools::Platform* platform) { - ServiceOptions default_options; - default_options.set_platform(platform); - return NewService(default_options); -} - /* static */ StatusOr> LocalService::NewService( const ServiceOptions& options) { perftools::gputools::Platform* platform = options.platform(); @@ -62,7 +56,6 @@ namespace xla { BackendOptions backend_options; backend_options.set_platform(platform) - .set_number_of_replicas(options.number_of_replicas()) .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()); TF_ASSIGN_OR_RETURN(std::unique_ptr backend, Backend::CreateBackend(backend_options)); @@ -70,15 +63,15 @@ namespace xla { TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); std::unique_ptr service(new LocalService( - std::move(backend), std::move(compute_constant_backend))); + options, std::move(backend), std::move(compute_constant_backend))); return std::move(service); } -LocalService::LocalService(std::unique_ptr execute_backend, +LocalService::LocalService(const ServiceOptions& options, + std::unique_ptr execute_backend, std::unique_ptr compute_constant_backend) - : Service(std::move(execute_backend), std::move(compute_constant_backend)) { - runs_in_client_process_ = true; -} + : Service(options, std::move(execute_backend), + std::move(compute_constant_backend)) {} namespace { // Returns the space required to allocate a shape. If @@ -152,7 +145,12 @@ StatusOr> LocalService::CompileExecutable( // Construct computation layout from the argument layouts. 
auto module_config = MakeUnique(*program_shape); module_config->set_has_hybrid_result(has_hybrid_result); - module_config->set_replica_count(execute_backend_->Replicas().size()); + module_config->set_replica_count(options_.number_of_replicas()); + module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) { + module_config->set_intra_op_parallelism_threads( + execute_backend_->eigen_intra_op_thread_pool()->NumThreads()); + } legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); if (flags->xla_hlo_profile) { module_config->enable_hlo_profiling(true); diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 767a3ab697febb283af448b25369445152381a5e..13797ec0450bd0eb2030b111464c42e966792266 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -35,11 +35,7 @@ namespace xla { // in the same process as the client. class LocalService : public Service { public: - // Factory for creating a LocalService. The parameter platform is the platform - // that the service should target. If platform is null then the default - // platform is used. - static StatusOr> NewService( - perftools::gputools::Platform* platform); + // Factory for creating a LocalService. 
static StatusOr> NewService( const ServiceOptions& options); @@ -60,7 +56,8 @@ class LocalService : public Service { const Shape* result_layout, int device_ordinal, bool has_hybrid_result); private: - explicit LocalService(std::unique_ptr backend, + explicit LocalService(const ServiceOptions& options, + std::unique_ptr backend, std::unique_ptr compute_constant_backend); LocalService(const LocalService&) = delete; void operator=(const LocalService&) = delete; diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index d24a592f46ed2dd8fd9c927e8ed9816771a7396c..3e843b202997a09f76993acd4d02f4de9aae9854 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -27,7 +27,7 @@ limitations under the License. namespace xla { string LogicalBuffer::ToString() const { - return tensorflow::strings::StrCat(instruction_->FullyQualifiedName(), "[", + return tensorflow::strings::StrCat(instruction_->name(), "[", tensorflow::str_util::Join(index_, ","), "](#", id_, " @", color_.value(), ")"); } diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index 566cd01ea437433e5e328ad523090e682a799233..a9f6688612002f320541b7c1d20df4dd41ea971a 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -95,11 +95,13 @@ class LogicalBuffer { // Functions which return the size and alignment of a logical buffer in bytes. 
using SizeFunction = std::function; - using AlignmentFunction = std::function; + using AlignmentFunction = std::function; - LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id, - Color color) - : instruction_(instruction), index_(index), id_(id), color_(color) {} + LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id) + : instruction_(instruction), + index_(index), + id_(id), + color_(kInvalidColor) {} Id id() const { return id_; } @@ -112,8 +114,19 @@ class LogicalBuffer { // Return the color of the logical buffer. Differently colored buffers can // not be parts of the same allocation. - Color color() const { return color_; } - void set_color(Color color) { color_ = color; } + Color color() const { + CHECK_NE(color_, kInvalidColor) + << "Should not query the color of a buffer that was never colored"; + return color_; + } + + void set_color(Color color) { + CHECK_NE(color, kInvalidColor) + << "Should not set the color of a buffer to the invalid color"; + color_ = color; + } + + bool has_color() const { return color_ != kInvalidColor; } // Return the shape of the buffer. This reference points into the shape field // of the instruction defining the buffer. Therefore, the returned shape will @@ -143,6 +156,8 @@ class LogicalBuffer { static LogicalBufferProto::Location ToLocationProto( const HloInstruction& instruction, const ShapeIndex& index); + const Color kInvalidColor = Color(-1); + private: HloInstruction* instruction_; ShapeIndex index_; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc new file mode 100644 index 0000000000000000000000000000000000000000..dafefdc4910a2ce3ed03bb23362c8c44d4e11cfb --- /dev/null +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr ReducePrecisionInsertion::Run(HloModule* module) { + bool changed = false; + VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name(); + + for (auto& computation : module->computations()) { + std::vector instructions_to_suffix; + + for (auto& instruction : computation->instructions()) { + VLOG(3) << "Visited instruction: " << instruction->ToString(); + + // For now, ReducePrecision is only implemented for F32 data, so this + // ignore instructions that produce other data. In particular, this + // currently ignores instructions producing tuples, even if those tuples + // contain F32 data inside them. The assumption is that in most cases + // equivalent behavior can be obtained by adding ReducePrecision + // instructions after the instructions that pull the F32 data out of the + // tuples. 
+ if (instruction->shape().element_type() == PrimitiveType::F32 && + should_reduce_output_precision_(instruction->opcode())) { + instructions_to_suffix.push_back(instruction.get()); + } + } + + for (auto& instruction : instructions_to_suffix) { + HloInstruction* reduced = + computation->AddInstruction(HloInstruction::CreateReducePrecision( + instruction->shape(), instruction, exponent_bits_, + mantissa_bits_)); + TF_RETURN_IF_ERROR( + computation->ReplaceUsesOfInstruction(instruction, reduced)); + VLOG(2) << "Inserted new op after instruction: " + << instruction->ToString(); + changed = true; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h new file mode 100644 index 0000000000000000000000000000000000000000..e9c8bba0313e4ba4622560b95484630427c05abf --- /dev/null +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_ + +#include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +// HLO pass which inserts reduce-precision instructions into the HLO graph, for +// purposes of experimenting with the effects of reduced-precision storage of +// intermediate values. +class ReducePrecisionInsertion : public HloPassInterface { + using OpcodeFilterFunction = std::function; + + public: + // The exponent_bits and mantissa_bits arguments specify the parameters of + // the instructions to insert. The instructions will be inserted after each + // instruction with an opcode for which the should_reduce_output_precision + // function returns true and the output type is F32. + explicit ReducePrecisionInsertion( + const int exponent_bits, const int mantissa_bits, + const OpcodeFilterFunction& should_reduce_output_precision) + : exponent_bits_(exponent_bits), + mantissa_bits_(mantissa_bits), + should_reduce_output_precision_(should_reduce_output_precision) {} + ~ReducePrecisionInsertion() override{}; + + tensorflow::StringPiece name() const override { + return "reduce-precision-insertion"; + } + + // Run the pass on the given module. Returns whether the module was changed + // (reduce-precision instructions were inserted). + StatusOr Run(HloModule* module) override; + + private: + // Parameters for the precision reduction to be added. 
+ const int exponent_bits_; + const int mantissa_bits_; + + // Function to determine (from the opcode) whether a given instruction should + // have a reduce-precision instruction inserted in its output stream. + const OpcodeFilterFunction& should_reduce_output_precision_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..80717ec2e3f43a968b04dae1367cb7f78fa08b25 --- /dev/null +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { + +using ::testing::UnorderedElementsAre; + +class ReducePrecisionInsertionTest : public HloTestBase { + protected: + bool InsertOps(HloModule* module, + const std::function& filter) { + ReducePrecisionInsertion op_insertion(5, 10, filter); + StatusOr result = op_insertion.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(ReducePrecisionInsertionTest, RootInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + // Create a simple graph with a parameter feeding a unary cosine function. + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + + EXPECT_TRUE(InsertOps(module.get(), + [](HloOpcode h) { return h == HloOpcode::kCos; })); + + // Confirm expected graph after adding ops. 
+ EXPECT_THAT(computation->root_instruction(), op::ReducePrecision()); + EXPECT_EQ(computation->root_instruction()->operand(0), b); +} + +TEST_F(ReducePrecisionInsertionTest, NonRootInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + // Create a graph with two parameters feeding into unary cosine functions, + // and the output of those feeds into an add function. Feeding the outputs + // from the suffixed cosine functions into a binary add function allows us to + // confirm that the separate operand streams are not crossed when the new + // instructions are inserted. + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* a_cos = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* b_cos = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, b)); + + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + // Confirm expected graph before adding ops. + EXPECT_EQ(c->operand(0), a_cos); + EXPECT_EQ(c->operand(1), b_cos); + + EXPECT_TRUE(InsertOps(module.get(), + [](HloOpcode h) { return h == HloOpcode::kCos; })); + + // Confirm expected graph after adding ops. 
+ EXPECT_THAT(c->operand(0), op::ReducePrecision()); + EXPECT_EQ(c->operand(0)->operand(0), a_cos); + EXPECT_THAT(c->operand(1), op::ReducePrecision()); + EXPECT_EQ(c->operand(1)->operand(0), b_cos); +} + +TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(S32, {4}); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected graph before adding ops. + EXPECT_THAT(x->users(), UnorderedElementsAre(y)); + EXPECT_EQ(computation->root_instruction(), y); + + // Since none of the instructions produce F32 data, this should not change + // the graph. + EXPECT_FALSE(InsertOps(module.get(), [](HloOpcode) { return true; })); + + // Confirm that graph has not changed. + EXPECT_THAT(x->users(), UnorderedElementsAre(y)); + EXPECT_EQ(computation->root_instruction(), y); +} + +TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected graph before adding ops. + EXPECT_THAT(x->users(), UnorderedElementsAre(y)); + EXPECT_EQ(computation->root_instruction(), y); + + // Since none of the instructions match the should_reduce_output_precision + // function, this should not change the graph. 
+ EXPECT_FALSE(InsertOps(module.get(), [](HloOpcode h) { return false; })); + + // Confirm that graph has not changed. + EXPECT_THAT(x->users(), UnorderedElementsAre(y)); + EXPECT_EQ(computation->root_instruction(), y); +} + +TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateReducePrecision(shape, a, 9, 23)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + + // This should insert a new ReducePrecision after the existing one, but + // should not then recurse by adding another after the just-inserted one. + EXPECT_TRUE(InsertOps(module.get(), [](HloOpcode h) { + return h == HloOpcode::kReducePrecision; + })); + + // Confirm expected graph after adding ops. 
+ EXPECT_THAT(computation->root_instruction(), op::ReducePrecision()); + EXPECT_EQ(computation->root_instruction()->operand(0), b); +} + +} // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 9becdb2bed480d610e658303ee7deff4cf7d2743..49c175552028dd3d6fbd08343afd65fd59d647a1 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -84,7 +84,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0)); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateFromShape(root_shape))); + HloInstruction::CreateConstant(Literal::CreateFromShape(root_shape))); builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); @@ -179,9 +179,8 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{true, true, false}, {false, false, true}}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{true, true, false}, {false, false, true}}))); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1")); @@ -263,12 +262,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + 
Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); @@ -318,7 +317,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0")); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -464,7 +463,7 @@ TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {128, 1}), param0)); Array2D a(128, 1024); - auto literal = LiteralUtil::CreateR2FromArray2D(a); + auto literal = Literal::CreateR2FromArray2D(a); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 85ca7e4e59ce9e69a671b829f3c2c3a4834a99ce..44e0fabe771ec2f5dd5647330213d5099593976c 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h" #include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -141,12 +142,13 @@ int ServiceOptions::intra_op_parallelism_threads() const { } BackendOptions backend_options; backend_options.set_platform(platform); - backend_options.set_number_of_replicas(options.number_of_replicas()); TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); + TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); - std::unique_ptr service(new Service( - std::move(execute_backend), std::move(compute_constant_backend))); + std::unique_ptr service( + new Service(options, std::move(execute_backend), + std::move(compute_constant_backend))); return std::move(service); } @@ -158,24 +160,30 @@ Service::CreateComputeConstantBackend() { if (platform->id() == se::host::kHostPlatformId) { BackendOptions backend_options; backend_options.set_platform(platform); - backend_options.set_number_of_replicas(1); return Backend::CreateBackend(backend_options); } } return NotFound("CPU platform not found"); } -/* static */ Compiler::HloDumper Service::MakeHloDumper() { - return [](const HloModule& module, const string& label) { - return Executable::DumpExecutedHlo(module, label, /*profile=*/nullptr); - }; -} - -Service::Service(std::unique_ptr execute_backend, +Service::Service(const ServiceOptions& options, + std::unique_ptr execute_backend, std::unique_ptr compute_constant_backend) - : execute_backend_(std::move(execute_backend)), + : options_(options), + execute_backend_(std::move(execute_backend)), compute_constant_backend_(std::move(compute_constant_backend)) { + // TODO(b/32648682): this flag / options update dance will go away once we + // pass the replica count explicitly to the service. 
+ if (options_.number_of_replicas() < 0) { + legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags(); + options_.set_number_of_replicas(flags->xla_replicas); + } + if (execute_backend_) { + if (execute_backend_->device_count() > 0) { + CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) + << "Requested more replicas than there are devices."; + } LOG(INFO) << Printf( "XLA service %p executing computations on platform %s. Devices:", this, execute_backend_->platform()->Name().c_str()); @@ -325,7 +333,7 @@ StatusOr> Service::CreateModuleConfig( module_config->enable_hlo_profiling(true); } - module_config->set_replica_count(backend->Replicas().size()); + module_config->set_replica_count(options_.number_of_replicas()); module_config->set_seed(execution_options.seed()); module_config->set_debug_options(execution_options.debug_options()); @@ -378,11 +386,9 @@ StatusOr>> Service::BuildExecutables( modules.push_back(std::move(module)); } - Compiler::HloDumper hlo_dumper = MakeHloDumper(); TF_ASSIGN_OR_RETURN( std::vector> executables, - backend->compiler()->Compile(std::move(modules), hlo_dumper, - std::move(executors))); + backend->compiler()->Compile(std::move(modules), std::move(executors))); if (!other_directory_path.empty()) { for (size_t i = 0; i < versioned_handles.size(); ++i) { @@ -429,15 +435,9 @@ StatusOr> Service::BuildExecutable( /*include_unreachable_instructions=*/ !executable_for_compute_constant)); - Compiler::HloDumper hlo_dumper = MakeHloDumper(); - if (executable_for_compute_constant && - !flags->xla_hlo_graph_for_compute_constant) { - hlo_dumper = [](const HloModule&, const string&) {}; - } - TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend->compiler()->Compile(std::move(module), hlo_dumper, executor)); + backend->compiler()->Compile(std::move(module), executor)); if (!other_directory_path.empty()) { executable->set_session_module(std::move(session_module)); @@ -495,47 +495,55 @@ 
Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice< std::vector> arguments, - Backend* backend, - tensorflow::gtl::ArraySlice executors, + Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags) { - // TODO(b/33943292): Support for replication when using multiple computations. - TF_RET_CHECK(backend->Replicas().size() == 1); - - // Set up streams. + // Streams where the computation are launched, so we can wait on the streams + // to complete. std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : executors) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - backend->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - // Set up run options. - std::vector run_options; - for (const Pool::SmartPtr& stream : streams) { - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); - options.set_intra_op_thread_pool( - backend->eigen_intra_op_thread_pool_device()); - run_options.emplace_back(options, backend->StreamBorrower()); - } - - // Asynchronously launch all executables. + // Global data handles for the computation results, one for each computation. std::vector result_handles; - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < executables.size(); i++) { - TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase result, - executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i])); - result_handles.push_back(allocation_tracker_.Register( - backend, executors[i]->device_ordinal(), result, - executables[i]->result_shape(), result_tags[i])); + + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + backend->computation_placer()->AssignDevices( + options_.number_of_replicas(), executables.size())); + + for (int64 i = 0; i < executables.size(); i++) { + // Stream executors for the replicas of the current computation. 
+ TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + backend->BorrowStream(replicas[replica])); + streams.push_back(std::move(stream)); + + // Set up run options. + ExecutableRunOptions options; + options.set_stream(streams.back().get()); + options.set_allocator(backend->memory_allocator()); + options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); + options.set_intra_op_thread_pool( + backend->eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + ServiceExecutableRunOptions run_options(options, + backend->StreamBorrower()); + + // Asynchronously launch the computation. + TF_ASSIGN_OR_RETURN( + perftools::gputools::DeviceMemoryBase result, + executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + + // All replicas share the same device address for the result allocation, + // so only one of the replicas need to register the result handle. + if (replica == 0) { + result_handles.push_back(allocation_tracker_.Register( + backend, replicas[0]->device_ordinal(), result, + executables[i]->result_shape(), result_tags[i])); + } + } } // Wait for all executions to complete. - for (int64 i = 0; i < result_handles.size(); ++i) { + for (int64 i = 0; i < streams.size(); ++i) { if (!streams[i]->BlockHostUntilDone()) { return InternalError("failed to complete execution for stream %lld", i); } @@ -550,17 +558,23 @@ StatusOr Service::ExecuteAndRegisterResult( arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile) { - TF_RET_CHECK(!backend->Replicas().empty()); - // Set up streams. 
std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : backend->Replicas()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*backend, SingleComputationDeviceHandle())); + TF_RET_CHECK(!replicas.empty()); + for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, backend->BorrowStream(executor)); streams.push_back(std::move(stream)); } + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + backend->computation_placer()->AssignDevices( + options_.number_of_replicas(), + /*computation_count=*/1)); + // Set up run options. std::vector run_options; for (const Pool::SmartPtr& stream : streams) { @@ -570,19 +584,20 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); run_options.emplace_back(options, backend->StreamBorrower(), backend->inter_op_thread_pool()); } perftools::gputools::DeviceMemoryBase result; - if (backend->Replicas().size() == 1) { + if (options_.number_of_replicas() == 1) { TF_ASSIGN_OR_RETURN( result, executable->ExecuteOnStreamWrapper( &run_options[0], profile, arguments)); } else { std::vector< tensorflow::gtl::ArraySlice> - repeated_arguments(backend->Replicas().size(), arguments); + repeated_arguments(options_.number_of_replicas(), arguments); TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( run_options, repeated_arguments)); @@ -610,25 +625,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, std::vector versioned_handles; std::vector> module_configs; std::vector computation_names; + std::vector device_handles; - if (arg->requests_size() > execute_backend_->stream_executors().size()) { + if (arg->requests_size() * options_.number_of_replicas() > + execute_backend_->device_count()) { return FailedPrecondition( "there are not enough stream executors to 
execute %d computations", arg->requests_size()); } for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor on which the computation will run. Select the - // specific device if requested, otherwise select the i'th device from the - // list of available stream executors. - se::StreamExecutor* executor; - if (arg->requests(i).has_device_handle()) { - executor = - execute_backend_ - ->stream_executors()[arg->requests(i).device_handle().handle()]; - } else { - executor = execute_backend_->stream_executors()[i]; + // Get the stream executor for the i'th computation. This stream executor + // is one of the executors to run the replicated computation. + if (!arg->requests(i).has_device_handle()) { + return FailedPrecondition( + "device handles must be given to execute parallel computations"); } + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, arg->requests(i).device_handle())); + se::StreamExecutor* executor = replicas[0]; CHECK(executor != nullptr); // Resolve the UserComputation object associated with the requested @@ -673,6 +689,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, module_configs.push_back(std::move(module_config)); computation_names.push_back(user_computation->name()); executors.push_back(executor); + device_handles.push_back(arg->requests(i).device_handle()); } // Build the user computations into HloModules and compile to generate the @@ -692,7 +709,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, TF_ASSIGN_OR_RETURN( std::vector outputs, ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), executors, + execute_backend_.get(), device_handles, computation_names)); for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; @@ -706,10 +723,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, tensorflow::Status Service::GetDeviceHandles(const 
GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) { - const int64 available_device_count = - execute_backend_->stream_executors().size(); - const int64 replicas = execute_backend_->Replicas().size(); - if (available_device_count < arg->device_count() * replicas) { + const int64 available_device_count = execute_backend_->device_count(); + const int64 replica_count = options_.number_of_replicas(); + if (replica_count <= 0) { + return FailedPrecondition("Replica count must be a positive integer"); + } + if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( "Requested device count (%lld) exceeds the number of available devices " "on the target (%lld)", @@ -718,8 +737,8 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, for (int64 i = 0; i < arg->device_count(); ++i) { DeviceHandle device_handle; - device_handle.set_handle( - execute_backend_->stream_executors()[i * replicas]->device_ordinal()); + device_handle.set_handle(i); + device_handle.set_device_count(arg->device_count()); *result->add_device_handles() = device_handle; } @@ -841,11 +860,14 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, execute_backend_->default_stream_executor(), &profile)); - TF_RET_CHECK(!execute_backend_->Replicas().empty()); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_RET_CHECK(!replicas.empty()); + // Set up streams. 
std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : execute_backend_->Replicas()) { + for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, execute_backend_->BorrowStream(executor)); streams.push_back(std::move(stream)); @@ -927,19 +949,20 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, Literal literal = Literal(arg->literal()); const Shape& shape = literal.shape(); - if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) { + if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) { // TODO(b/32990684): Tuple transfers to host end up allocating further // buffers - implement that correctly. return Unimplemented( "Tuple transfers to the device not supported with replication."); } - se::StreamExecutor* stream_executor; + std::vector replicas; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(replicas, + Replicas(*execute_backend_, arg->device_handle())); } else { - stream_executor = execute_backend_->default_stream_executor(); + TF_ASSIGN_OR_RETURN( + replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } // Allocate memory on the device, using the stream executor. 
The size of the @@ -950,14 +973,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, execute_backend_->memory_allocator()->Allocate( - stream_executor->device_ordinal(), allocation_size)); + replicas[0]->device_ordinal(), allocation_size)); *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), stream_executor->device_ordinal(), allocation, - shape, StrCat("TransferToServer literal of size ", allocation_size)); + execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape, + StrCat("TransferToServer literal of size ", allocation_size)); - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - stream_executor->device_ordinal())); for (se::StreamExecutor* executor : replicas) { TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( @@ -968,7 +989,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) { - const int64 replica_count = execute_backend_->Replicas().size(); + const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( "%s", @@ -980,11 +1001,14 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, se::StreamExecutor* executor; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, arg->device_handle())); executor = replicas[arg->replica_id()]; } else { - executor = execute_backend_->Replicas()[arg->replica_id()]; + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); + executor = replicas[arg->replica_id()]; } return 
execute_backend_->transfer_manager()->TransferLiteralToInfeed( @@ -994,7 +1018,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, tensorflow::Status Service::TransferFromOutfeed( const TransferFromOutfeedRequest* arg, TransferFromOutfeedResponse* result) { - const int64 replica_count = execute_backend_->Replicas().size(); + const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " @@ -1004,11 +1028,14 @@ tensorflow::Status Service::TransferFromOutfeed( se::StreamExecutor* executor; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, arg->device_handle())); executor = replicas[arg->replica_id()]; } else { - executor = execute_backend_->Replicas()[arg->replica_id()]; + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); + executor = replicas[arg->replica_id()]; } Literal literal; @@ -1146,11 +1173,14 @@ tensorflow::Status Service::GetComputationStats( VersionedComputationHandle versioned_handle = user_computation->GetVersionedHandle(); + HloModuleConfig config; + config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN( std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig())); + computation_tracker_.BuildHloModule(versioned_handle, config)); - MakeHloDumper()(*module, "computation statistics subject"); + hlo_graph_dumper::MaybeDumpHloModule(*module, + "computation statistics subject"); // Run HLO analysis to get the computation statistics. 
HloCostAnalysis analysis( @@ -1166,17 +1196,6 @@ tensorflow::Status Service::GetComputationStats( return tensorflow::Status::OK(); } -tensorflow::Status Service::CheckRunsInClientProcess( - const string& method_name) const { - if (runs_in_client_process_) { - return tensorflow::Status::OK(); - } else { - return FailedPrecondition( - "%s only supported if service runs in the same process as the client", - method_name.c_str()); - } -} - template tensorflow::Status Service::AddInstruction( const RequestT* arg, ResponseT* result, @@ -1195,6 +1214,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { StatusOr handle_status; switch (arg->op_case()) { + case OpRequest::kBatchNormTrainingRequest: + handle_status = computation->AddBatchNormTrainingInstruction( + arg->batch_norm_training_request()); + break; case OpRequest::kBinaryOpRequest: handle_status = computation->AddBinaryInstruction(arg->binary_op_request()); @@ -1277,6 +1300,11 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { computation->AddReduceInstruction(arg->reduce_request(), *to_apply); break; } + case OpRequest::kReducePrecisionRequest: { + handle_status = computation->AddReducePrecisionInstruction( + arg->reduce_precision_request()); + break; + } case OpRequest::kReduceWindowRequest: { TF_ASSIGN_OR_RETURN(UserComputation * to_apply, computation_tracker_.Resolve( @@ -1383,4 +1411,28 @@ tensorflow::Status Service::LoadComputationSnapshot( return tensorflow::Status::OK(); } +DeviceHandle Service::SingleComputationDeviceHandle() const { + DeviceHandle device_handle; + device_handle.set_handle(0); + device_handle.set_device_count(1); + return device_handle; +} + +StatusOr> Service::Replicas( + const Backend& backend, const DeviceHandle& device_handle) const { + std::vector replicas; + for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { + // From the computation placer, find out the device ids of the replicas for + // the given device 
handle. + TF_ASSIGN_OR_RETURN( + int device_ordinal, + backend.computation_placer()->DeviceId(replica, device_handle.handle(), + options_.number_of_replicas(), + device_handle.device_count())); + TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal)); + replicas.push_back(executor); + } + return replicas; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index abd1281bdd0ab76297bc64493ec77bbc35fb552b..48cb0bec4757f55d4d5fcd35f589511fd7393ade 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" #include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -126,7 +125,7 @@ class Service : public ServiceInterface { // least N * R devices must be available. The devices are assigned based on // the device ordinals such that the first R available devices are assigned to // the first set of replicas, and the next R devices to the second set of - // replicas, etc. Each returned device handles represent the device with the + // replicas, etc. Each returned device handle represents the device with the // replica id 0. tensorflow::Status GetDeviceHandles( const GetDeviceHandlesRequest* arg, @@ -248,7 +247,7 @@ class Service : public ServiceInterface { // The constructor is private. Use the NewService factory to create new // service objects. 
- Service(std::unique_ptr backend, + Service(const ServiceOptions& options, std::unique_ptr backend, std::unique_ptr compute_constant_backend); static StatusOr> CreateComputeConstantBackend(); @@ -319,14 +318,9 @@ class Service : public ServiceInterface { std::vector> arguments, Backend* backend, - tensorflow::gtl::ArraySlice - executors, + tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags); - // Returns an HLO dumper for use in the compiler (it refers to flags - // associated with the service). - static Compiler::HloDumper MakeHloDumper(); - // Convenience function for adding a function to a user computation. template tensorflow::Status AddInstruction( @@ -334,18 +328,24 @@ class Service : public ServiceInterface { const std::function(UserComputation*)>& adder); - // If the service is running in the client process - // (runs_in_client_process_ is true) then return - // tensorflow::Status::OK. Otherwise return an appropriate error - // status with the given method name. Used for "InProcess" methods. - tensorflow::Status CheckRunsInClientProcess(const string& method_name) const; - // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. tensorflow::Status ValidateResultShapeWithLayout( const Shape& shape_with_layout, const Shape& result_shape) const; + // Returns the stream executors assigned to the replicas represented by the + // given device handle. Each device_handle is a virtual replicated device that + // represents a set of physical devices for the replicas. + StatusOr> Replicas( + const Backend& backend, const DeviceHandle& device_handle) const; + + // Returns the device handle that represents the replicated device for a + // single computation that is not model-parallelized. 
+ DeviceHandle SingleComputationDeviceHandle() const; + + ServiceOptions options_; + // Tracks computations built via the API. ComputationTracker computation_tracker_; @@ -369,9 +369,6 @@ class Service : public ServiceInterface { // Backend to use when executing ComputeConstant. std::unique_ptr compute_constant_backend_; - // Whether the service runs in the same process as the client. - bool runs_in_client_process_ = false; - TF_DISALLOW_COPY_AND_ASSIGN(Service); }; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d6436cf988db7632ecf89f1a1e274a0fbab00ce2..f02df232d8a6ac112b44b51fa7546e9cf945b1b0 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -184,6 +184,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (operation) { case UNOP_FLOOR: case UNOP_CEIL: + case UNOP_COS: case UNOP_EXP: case UNOP_LOG: case UNOP_TANH: @@ -297,6 +298,30 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } +/* static */ StatusOr ShapeInference::InferReducePrecisionShape( + const Shape& operand_shape, const int exponent_bits, + const int mantissa_bits) { + if (!ShapeUtil::ElementIsFloating(operand_shape)) { + return InvalidArgument( + "expected element type in shape to be floating point for " + "ReducePrecision operation; got %s", + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + if (exponent_bits < 1) { + // One exponent bit is necessary to distinguish 0 from infinity. Having + // no exponent bits doesn't produce a sensible number, so we require at + // least one. + return InvalidArgument("expected exponent_bits >= 1; got %d", + exponent_bits); + } + if (mantissa_bits < 0) { + // A number with no mantissa bits is still meaningful, however. 
+ return InvalidArgument("expected non-negative mantissa_bits; got %d", + mantissa_bits); + } + return operand_shape; +} + /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { @@ -525,9 +550,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementType(lhs, rhs)) { - return InvalidArgument("binary op with different element types: %s and %s", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + return InvalidArgument( + "binary op %s with different element types: %s and %s", + BinaryOperation_Name(operation).c_str(), + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && @@ -754,6 +781,109 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( AsInt64Slice(arg_shape->dimensions())); } +/* static */ StatusOr ShapeInference::InferBatchNormTrainingShape( + const Shape& operand_shape, const Shape& offset_shape, + const Shape& scale_shape, int64 feature_index) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + offset_shape, "offset input of batch norm training")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + scale_shape, "scale input of batch norm training")); + + TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) == + tensorflow::Status::OK()); + + if (feature_index >= ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Expected feature_index of batch-norm-training to be " + "smaller than the rank of operand_shape; " + "got feature_index %lld, and rank %lld", + feature_index, 
ShapeUtil::Rank(operand_shape)); + } + + if (feature_index < 0) { + return InvalidArgument( + "Expected feature_index of batch-norm-training to " + "be a non-negative number, got %lld", + feature_index); + } + + if (ShapeUtil::Rank(operand_shape) < 1) { + return InvalidArgument( + "Expected the rank of operand to " + "batch-norm-training to be at least 1; got %lld", + ShapeUtil::Rank(operand_shape)); + } + + if (ShapeUtil::Rank(offset_shape) != 1) { + return InvalidArgument( + "Offset input of batch-norm-training must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(offset_shape)); + } + + if (ShapeUtil::Rank(scale_shape) != 1) { + return InvalidArgument( + "Scale input of batch-norm-training must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(scale_shape)); + } + + if (!ShapeUtil::ElementIsFloating(operand_shape)) { + return InvalidArgument( + "The operand to batch-norm-training must have a floating point " + "element type, but the shape is %s", + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-training, " + "but the shape of offset factor is %s " + "and the shape of operand is %s", + PrimitiveType_Name(offset_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-training, " + "but the shape of scale factor is %s " + "and the shape of operand is %s", + PrimitiveType_Name(scale_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + const int64 feature_count = operand_shape.dimensions(feature_index); + Shape output_shape_for_mean_and_var = + ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}); + + if
(ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { + return InvalidArgument( + "The size of offset factor should be the same as feature count, " + "but the size of offset factor is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(offset_shape, 0), feature_count); + } + + if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { + return InvalidArgument( + "The size of scale factor should be the same as feature count, " + "but the size of scale factor is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(scale_shape, 0), feature_count); + } + + return ShapeUtil::MakeTupleShape({operand_shape, + output_shape_for_mean_and_var, + output_shape_for_mean_and_var}); +} + /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dnums) { @@ -1019,6 +1149,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( starts.size(), limits.size()); } + if (starts.size() != strides.size()) { + return InvalidArgument("slice start and strides sizes differ: %zu vs %zu", + starts.size(), strides.size()); + } + if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( "slice index count does not match argument rank: %zu vs %lld", @@ -1034,9 +1169,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument("negative start index to slice: %lld", start_index); } - if (stride == 0) { - return InvalidArgument("Zero stride"); - } if (limit_index > arg.dimensions(dimension)) { return InvalidArgument( "limit index (%lld) must be less than or equal to dimension " @@ -1047,17 +1179,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( start_index); VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, limit_index); - if (stride > 0) { - if (start_index > limit_index) { - return InvalidArgument( - "limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive
stride", - limit_index, start_index); - } - sizes.push_back((limit_index - start_index + stride - 1) / stride); - } else { - return InvalidArgument("Negative strides not supported"); + if (start_index > limit_index) { + return InvalidArgument( + "limit index (%lld) must be greater or equal to " + "start index (%lld) in slice with positive stride", + limit_index, start_index); } + if (stride <= 0) { + return InvalidArgument("stride (%lld) must be positive", stride); + } + sizes.push_back((limit_index - start_index + stride - 1) / stride); } return ShapeUtil::MakeShape(arg.element_type(), sizes); @@ -1394,10 +1525,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const ProgramShape& to_apply) { // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { + string computation_signature = ShapeUtil::HumanString(to_apply); + string argument_shapes = tensorflow::str_util::Join( + arg_shapes, ", ", [](string* out, const Shape* shape) { + tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape)); + }); return InvalidArgument( "Call applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu", - to_apply.parameters_size(), arg_shapes.size()); + "arity: %d, arguments: %zu; computation signature: %s; argument " + "shapes: [%s]", + to_apply.parameters_size(), arg_shapes.size(), + computation_signature.c_str(), argument_shapes.c_str()); } // All arguments must be compatible with the program shape. 
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0d270f99794bd7a17a1df555b9b666a50d4b7e17..42e4c7d39d25c72c077f04a11353d72e6afda245 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -64,6 +64,13 @@ class ShapeInference { tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply); + // Infers the shape produced by InferBatchNormTraining with the given + // operands. + static StatusOr InferBatchNormTrainingShape(const Shape& operand_shape, + const Shape& offset_shape, + const Shape& scale_shape, + int64 feature_index); + // Infers the shape produced by applying the given convolutional // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( @@ -165,6 +172,12 @@ class ShapeInference { static StatusOr InferConvertShape(const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the input data type for a reduce-precision operation, + // and returns the result shape. + static StatusOr InferReducePrecisionShape(const Shape& operand_shape, + const int exponent_bits, + const int mantissa_bits); + // Helper that infers the shape produced by a pad operation based on the // padding configuration. static StatusOr InferPadShape(const Shape& operand_shape, diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 15f6b7bfb4a7f507272471c406bd2ade3ab27b20..c79ffa9cd73950b1653f72b1c6286346f76c10fb 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -65,6 +65,17 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; + // Transfer a memory block of the given size from 'source' buffer to the + // Infeed interface of the device using the given executor. 
+ // + // size is the size to transfer from source in bytes. + // + // source is the source data that must be in the target-dependent layout that + // the Infeed HLO used in the computation expects. + virtual Status TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) = 0; + // Transfers the given literal from the Outfeed interface of the device, // using the given executor. virtual Status TransferLiteralFromOutfeed( diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc index ca38601d919adfdfd637dab44796ffa4969cc8f2..29ecef9510cfe6b8764c2e5fe1216255ca1dc983 100644 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc @@ -55,7 +55,7 @@ class CpuTransferManagerTest : public ::testing::Test { TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) { std::vector storage(sizeof(uint32), '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = LiteralUtil::CreateR0(42); + std::unique_ptr literal = Literal::CreateR0(42); TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); @@ -66,7 +66,7 @@ TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) { std::vector storage(4 * sizeof(float), '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); std::unique_ptr literal = - LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); @@ -80,7 +80,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) { std::vector storage(16, '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); const char* str = "0123456789abcdef"; - std::unique_ptr literal = LiteralUtil::CreateR1U8(str); + std::unique_ptr literal = Literal::CreateR1U8(str); 
TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index c72d127ea86e4e9daf99dff4335c538c081f0605..9520c42d280968e3f21a110089583c94277ef1a6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -92,11 +92,11 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { auto builder = HloComputation::Builder("entry_computation"); // 2x1 HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2({{1}, {2}}))); + HloInstruction::CreateConstant(Literal::CreateR2({{1}, {2}}))); // 3x2 HloInstruction* const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); HloInstruction* transpose0 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); @@ -130,11 +130,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { auto builder = HloComputation::Builder("entry"); // (1.0 + 2.0) * (2.0 - 3.0) HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* const2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* const3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( const1->shape(), HloOpcode::kAdd, const1, const2)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( diff --git 
a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index ad6f015c70e7241af815246b732fa02768cf0a10..182e99cf1ca88d9037fa7110fa5fe24241057614 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -33,9 +33,9 @@ limitations under the License. namespace xla { string BufferAlias::ToString() const { - return tensorflow::strings::StrCat( - "BufferAlias(", instruction_->FullyQualifiedName(), "[", - tensorflow::str_util::Join(index_, ","), "])"); + return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", + tensorflow::str_util::Join(index_, ","), + "])"); } std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { @@ -125,18 +125,13 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, } /* static */ StatusOr> -TuplePointsToAnalysis::Run(const HloModule* module, Colorer colorer) { +TuplePointsToAnalysis::Run(const HloModule* module) { std::unique_ptr analysis( - new TuplePointsToAnalysis(module, std::move(colorer))); + new TuplePointsToAnalysis(module)); TF_RETURN_IF_ERROR(analysis->Analyze()); return std::move(analysis); } -/* static */ StatusOr> -TuplePointsToAnalysis::Run(const HloModule* module) { - return Run(module, DefaultColorer()); -} - Status TuplePointsToAnalysis::Analyze() { points_to_.clear(); for (auto& computation : module_->computations()) { @@ -171,9 +166,6 @@ Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases( const ShapeIndex& index, const std::vector& pointed_to_buffers) { for (const LogicalBuffer* buffer : pointed_to_buffers) { - if (buffer_aliases_.count(buffer) == 0) { - buffer_aliases_.insert({buffer, std::vector()}); - } buffer_aliases_[buffer].emplace_back(instruction.get(), index); } }); @@ -184,8 +176,8 @@ Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases( const LogicalBuffer& 
TuplePointsToAnalysis::NewLogicalBuffer( HloInstruction* instruction, const ShapeIndex& index) { CHECK_EQ(logical_buffers_.size(), next_buffer_id_); - logical_buffers_.push_back(MakeUnique( - instruction, index, next_buffer_id_, colorer_(instruction, index))); + logical_buffers_.push_back( + MakeUnique(instruction, index, next_buffer_id_)); ++next_buffer_id_; return *logical_buffers_.back(); } @@ -243,12 +235,11 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( return Status::OK(); } -Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) { // A kCopy instruction performs a shallow copy of the operand. The top-level // buffer (index={}) is newly created, but all other buffers (in the case of a // tuple shape) come from the operand - PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, operand); + PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0)); points_to_set.mutable_element(/*index=*/{})->clear(); points_to_set.AddPointedToBuffer(NewLogicalBuffer(copy, /*index=*/{}), /*index=*/{}); @@ -343,9 +334,11 @@ const PointsToSet& TuplePointsToAnalysis::GetPointsToSet( PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( const HloInstruction* instruction) { - CHECK_EQ(0, points_to_.count(instruction)); - points_to_[instruction] = MakeUnique(instruction->shape()); - return *FindOrDie(points_to_, instruction); + auto set = MakeUnique(&instruction->shape()); + auto res = points_to_.emplace(instruction, std::move(set)); + CHECK(res.second) << "instruction should not have been present in the map."; + // Return *set using the iterator returned by emplace. 
+ return *res.first->second; } bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 4d7fc7cbc9e5ba2ac87dc6fd10691ce308b827f6..be821d5154f61f2d988890fd55539ad4b8fa2eac 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -48,7 +48,10 @@ namespace xla { // the corresponding buffer. class PointsToSet : public ShapeTree> { public: - explicit PointsToSet(const Shape& shape) + // Construct our ShapeTree with a pointer rather than a reference to a Shape + // because this is very hot code, and copying (and then destroying) all these + // Shapes is slow. + explicit PointsToSet(const Shape* shape) : ShapeTree>(shape), tuple_sources_(shape) {} @@ -142,15 +145,10 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); // the potential sources of each buffer in each instruction's output. class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { public: - using Colorer = std::function; + using Colorer = + std::function; - // Runs points-to analysis on 'module' with the provided buffer color - // assigner. - static StatusOr> Run( - const HloModule* module, Colorer colorer); - - // Runs points-to analysis on 'module' with the default color assigner. + // Runs points-to analysis on 'module'. 
static StatusOr> Run( const HloModule* module); @@ -208,7 +206,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) override; @@ -216,15 +214,16 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { string ToString() const; static Colorer DefaultColorer() { - return [](const HloInstruction* instruction, const ShapeIndex& index) { - return LogicalBuffer::Color(0); + return [](TuplePointsToAnalysis* points_to_analysis) { + for (auto& buffer : points_to_analysis->logical_buffers()) { + buffer->set_color(LogicalBuffer::Color(0)); + } + return Status::OK(); }; } private: - explicit TuplePointsToAnalysis(const HloModule* module, - Colorer colorer = DefaultColorer()) - : module_(module), colorer_(colorer) {} + explicit TuplePointsToAnalysis(const HloModule* module) : module_(module) {} // Perform the analysis. Should be called immediately after constructing the // object and before calling GetPointsToSet. @@ -283,9 +282,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // The ID of the next logical buffer created. LogicalBuffer::Id next_buffer_id_ = 0; - // Used to color the created logical buffers. 
- Colorer colorer_; - TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis); }; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 9909c11929d4b2ecf632ab644981a039446bdfc8..cd79e63cafcfecce71cf3380aba9e409da0e72c8 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase { TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { // tuple. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple which contains duplicate elements. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant, constant, constant})); @@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) { // the same. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto copy = builder.AddInstruction( @@ -318,16 +318,16 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // set containing the union of both sides. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -356,7 +356,7 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, param0, param1)); auto copy = builder.AddInstruction( @@ -396,16 +396,16 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { // Select from two identical tuples. The result should not be ambiguous. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -427,9 +427,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { // the right values. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto inner_tuple2 = builder.AddInstruction( @@ -441,7 +441,7 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -474,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { // have the operand 
of the bitcast in its points-to set. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( constant2->shape(), HloOpcode::kBitcast, constant2)); auto tuple = @@ -510,10 +510,9 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()}))); + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), + Literal::CreateR1({2.0, 42}).get()}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -533,9 +532,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) { // times. Verify buffer alias sets. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple = builder.AddInstruction( @@ -574,7 +573,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { auto tuple_element1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1)); auto ones = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones) auto update = builder.AddInstruction(HloInstruction::CreateBinary( update_shape, HloOpcode::kAdd, tuple_element1, ones)); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 4aba8875161c9a2d12668d57ea55ded066d38da0..90a24fb44d164069d0efbb5ccd36577549c7aedf 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -49,6 +48,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kAbs; case UNOP_CEIL: return HloOpcode::kCeil; + case UNOP_COS: + return HloOpcode::kCos; case UNOP_EXP: return HloOpcode::kExp; case UNOP_FLOOR: @@ -465,6 +466,45 @@ StatusOr UserComputation::AddReduceInstruction( return handle; } +StatusOr +UserComputation::AddBatchNormTrainingInstruction( + const BatchNormTrainingRequest& batch_norm_training_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(batch_norm_training_request.operand())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* scale, + LookUpRequest(batch_norm_training_request.scale())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* offset, + LookUpRequest(batch_norm_training_request.offset())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferBatchNormTrainingShape( + operand->output_shape(), scale->output_shape(), + offset->output_shape(), batch_norm_training_request.feature_index())); + + *request.mutable_output_shape() = inferred_shape; + + *request.mutable_output_handle() = handle; + + *request.mutable_request()->mutable_batch_norm_training_request() = + batch_norm_training_request; + + VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << batch_norm_training_request.ShortDebugString(); + + return handle; +} + StatusOr UserComputation::AddReduceWindowInstruction( const ReduceWindowRequest& 
reduce_window_request, const UserComputation& to_apply_computation) { @@ -841,6 +881,34 @@ StatusOr UserComputation::AddConvertInstruction( return handle; } +StatusOr UserComputation::AddReducePrecisionInstruction( + const ReducePrecisionRequest& reduce_precision_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(reduce_precision_request.operand())); + + TF_ASSIGN_OR_RETURN( + Shape new_shape, + ShapeInference::InferReducePrecisionShape( + operand->output_shape(), reduce_precision_request.exponent_bits(), + reduce_precision_request.mantissa_bits())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_reduce_precision_request() = + reduce_precision_request; + + VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << reduce_precision_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddConvolveInstruction( const ConvolveRequest& convolve_request) { tensorflow::mutex_lock lock(mutex_); @@ -1556,6 +1624,19 @@ void ConstantVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kBatchNormTrainingRequest: { + const BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + ConstantVisitor(session_computation, + batch_norm_training_request.operand(), visited, + is_constant); + ConstantVisitor(session_computation, batch_norm_training_request.scale(), + visited, is_constant); + ConstantVisitor(session_computation, batch_norm_training_request.offset(), + visited, is_constant); + break; + } + case OpRequest::kBinaryOpRequest: { const BinaryOpRequest& binary_op_request = 
request.request().binary_op_request(); @@ -1824,7 +1905,6 @@ Status UserComputation::CheckParametersAreContiguous( } } - auto program_shape = MakeUnique(); for (int64 i = 0; i < parameter_requests.size(); ++i) { auto it = parameter_requests.find(i); if (it == parameter_requests.end()) { @@ -1850,26 +1930,31 @@ class ComputationLowerer { const SessionComputation& session_computation, VersionedComputationHandle::Version version, UserComputation::HloComputationResolver hlo_resolver, + const DebugOptions& debug_options, bool include_unreachable_instructions) { ComputationLowerer lowerer(computation_name, session_computation, version, - std::move(hlo_resolver)); - return lowerer.Lower(include_unreachable_instructions); + std::move(hlo_resolver), debug_options, + include_unreachable_instructions); + return lowerer.Lower(); } private: ComputationLowerer(const string& computation_name, const SessionComputation& session_computation, VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver) + UserComputation::HloComputationResolver hlo_resolver, + const DebugOptions& debug_options, + bool include_unreachable_instructions) : hlo_builder_(computation_name), session_computation_(session_computation), version_(version), - hlo_resolver_(std::move(hlo_resolver)) {} + hlo_resolver_(std::move(hlo_resolver)), + debug_options_(debug_options), + include_unreachable_instructions_(include_unreachable_instructions) {} // Build an HLO computation from the SessionComputation at the given // version. - StatusOr> Lower( - bool include_unreachable_instructions); + StatusOr> Lower(); private: // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. 
@@ -1899,6 +1984,8 @@ class ComputationLowerer { const SessionComputation& session_computation_; const VersionedComputationHandle::Version version_; const UserComputation::HloComputationResolver hlo_resolver_; + const DebugOptions& debug_options_; + const bool include_unreachable_instructions_; }; // Calls 'apply' on each operand of 'request'. @@ -1964,6 +2051,16 @@ static void ForEachOperand( break; } + case OpRequest::kBatchNormTrainingRequest: { + const BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + + apply(batch_norm_training_request.operand()); + apply(batch_norm_training_request.scale()); + apply(batch_norm_training_request.offset()); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); @@ -2117,6 +2214,13 @@ static void ForEachOperand( break; } + case OpRequest::kReducePrecisionRequest: { + const ReducePrecisionRequest& reduce_precision_request = + request.request().reduce_precision_request(); + apply(reduce_precision_request.operand()); + break; + } + case OpRequest::kTraceRequest: { const TraceRequest& trace_request = request.request().trace_request(); apply(trace_request.operand()); @@ -2175,8 +2279,7 @@ void ComputationLowerer::TraversePostorder( } } -StatusOr> ComputationLowerer::Lower( - bool include_unreachable_instructions) { +StatusOr> ComputationLowerer::Lower() { // Map from ComputationDataHandle to HLO instruction. Serves as a record of // which operations have been visited as well as a cache for looking up // ComputationDataHandles as HloInstructions. @@ -2192,7 +2295,7 @@ StatusOr> ComputationLowerer::Lower( HloInstruction* hlo_root = instructions.at(root_request->output_handle().handle()); - if (include_unreachable_instructions) { + if (include_unreachable_instructions_) { // Iterate through all computation data handles, and visit any unvisited // operations. 
for (int64 request_num = 1; request_num <= version_; ++request_num) { @@ -2276,7 +2379,7 @@ void ComputationLowerer::Visit( const ConstantRequest& constant_request = request.request().constant_request(); hlo_instruction = add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CloneToUnique(Literal(constant_request.literal())))); + Literal(constant_request.literal()).CloneToUnique())); break; } @@ -2457,6 +2560,23 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kBatchNormTrainingRequest: { + const BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + HloInstruction* operand = + lookup_instruction(batch_norm_training_request.operand()); + HloInstruction* scale = + lookup_instruction(batch_norm_training_request.scale()); + HloInstruction* offset = + lookup_instruction(batch_norm_training_request.offset()); + + hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( + request.output_shape(), operand, scale, offset, + batch_norm_training_request.epsilon(), + batch_norm_training_request.feature_index())); + break; + } + case OpRequest::kBroadcastRequest: { const BroadcastRequest& broadcast_request = request.request().broadcast_request(); @@ -2670,8 +2790,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. 
lhs = @@ -2688,6 +2807,18 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kReducePrecisionRequest: { + const ReducePrecisionRequest& reduce_precision_request = + request.request().reduce_precision_request(); + HloInstruction* operand = + lookup_instruction(reduce_precision_request.operand()); + auto exponent_bits = reduce_precision_request.exponent_bits(); + auto mantissa_bits = reduce_precision_request.mantissa_bits(); + hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( + request.output_shape(), operand, exponent_bits, mantissa_bits)); + break; + } + case OpRequest::kTraceRequest: { const TraceRequest& trace_request = request.request().trace_request(); HloInstruction* operand = lookup_instruction(trace_request.operand()); @@ -2718,7 +2849,7 @@ void ComputationLowerer::Visit( StatusOr> UserComputation::BuildHloComputation( VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, + HloComputationResolver hlo_resolver, const DebugOptions& debug_options, bool include_unreachable_instructions) const { tensorflow::mutex_lock lock(mutex_); @@ -2730,7 +2861,7 @@ StatusOr> UserComputation::BuildHloComputation( std::unique_ptr hlo_computation, ComputationLowerer::Lower( tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), + session_computation_, version, std::move(hlo_resolver), debug_options, include_unreachable_instructions)); XLA_VLOG_LINES(2, hlo_computation->ToString()); diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index fb5425ae61ab1edcd00aac493c9e2ac3c430cb72..3cc3bd0918de40fb7042c0d304380b97da9abef9 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -84,6 +85,10 @@ class UserComputation { StatusOr AddUnaryInstruction( const UnaryOpRequest& unary_request); + // Enqueues a batch norm training instruction onto this user computation. + StatusOr AddBatchNormTrainingInstruction( + const BatchNormTrainingRequest& batch_norm_training_request); + // Enqueues a binary instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. StatusOr AddBinaryInstruction( @@ -112,6 +117,10 @@ class UserComputation { const MapRequest& map_request, const UserComputation& to_apply_computation); + // Enqueues a reduce-precision instruction onto this user computation. + StatusOr AddReducePrecisionInstruction( + const ReducePrecisionRequest& reduce_precision_request); + // Enqueues a convolution instruction onto this user computation. 
StatusOr AddConvolveInstruction( const ConvolveRequest& convolve_request); @@ -256,7 +265,7 @@ class UserComputation { std::function; StatusOr> BuildHloComputation( VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, + HloComputationResolver hlo_resolver, const DebugOptions& debug_options, bool include_unreachable_instructions = true) const; // Return a vector containing the embedded computations used by this diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index ea691201263e4935afbc29bcb8624a73c6715f83..0d50810dc4088e47c793592576cd72597419e87d 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) { ConstantRequest constant_request; *constant_request.mutable_literal() = - LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); + Literal::CreateR1({123.0f, 42.0f})->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle, computation.AddConstantInstruction(constant_request)); @@ -92,7 +92,8 @@ TEST_F(UserComputationTest, SimpleComputation) { // Build the HLO computation. TF_ASSIGN_OR_ASSERT_OK( std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // There should be one HloInstruction per UserComputation operation. 
EXPECT_EQ(3, hlo_computation->instruction_count()); // The root of the instruction should be the parameter instruction (not the @@ -117,9 +118,10 @@ TEST_F(UserComputationTest, SimpleComputation) { // There should be two instructions, one for the constant and one for the // parameter. The outfeed instruction should not be included. - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr hlo_computation, - computation.BuildHloComputation( - version_at_param.version, hlo_resolver)); + TF_ASSIGN_OR_ASSERT_OK( + std::unique_ptr hlo_computation, + computation.BuildHloComputation(version_at_param.version, hlo_resolver, + DebugOptions())); EXPECT_EQ(2, hlo_computation->instruction_count()); EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); } @@ -130,10 +132,11 @@ TEST_F(UserComputationTest, SimpleComputation) { computation.GetVersionedHandle(); // Build the HLO computation. - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, - /*include_unreachable_instructions=*/false)); + TF_ASSIGN_OR_ASSERT_OK( + std::unique_ptr hlo_computation, + computation.BuildHloComputation( + latest_version.version, hlo_resolver, DebugOptions(), + /*include_unreachable_instructions=*/false)); // There is only one reachable instruction, the parameter. 
EXPECT_EQ(1, hlo_computation->instruction_count()); // The root of the instruction should be the parameter instruction (not the @@ -145,8 +148,8 @@ TEST_F(UserComputationTest, SimpleComputation) { } TEST_F(UserComputationTest, EliminateScalarBroadcast) { - if (!legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (!legacy_flags::GetDebugOptionsFromFlags() + .xla_eliminate_hlo_implicit_broadcast()) { return; } @@ -161,12 +164,12 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { ConstantRequest a_request; *a_request.mutable_literal() = - LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); + Literal::CreateR1({123.0f, 42.0f})->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, computation.AddConstantInstruction(a_request)); ConstantRequest b_request; - *b_request.mutable_literal() = LiteralUtil::CreateR0(1.0f)->ToProto(); + *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, computation.AddConstantInstruction(b_request)); @@ -184,7 +187,8 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { // Build the HLO computation. TF_ASSIGN_OR_ASSERT_OK( std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // The binary operation has implicit scalar broadcast, should be converted // to an explicit broadcast intruction and a binary instruction. 
EXPECT_EQ(4, hlo_computation->instruction_count()); @@ -196,8 +200,8 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { } TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - if (!legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (!legacy_flags::GetDebugOptionsFromFlags() + .xla_eliminate_hlo_implicit_broadcast()) { return; } @@ -240,7 +244,8 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { // Build the HLO computation. TF_ASSIGN_OR_ASSERT_OK( std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // The binary operation has in-dim broadcast and degenerate broadcast, should // first do the in-dim broadcast then convert the degnerate broadcast into a @@ -266,7 +271,7 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendUserComputationFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index cc456df4fce5c78162c41ed36f6c69c0f5ab459b..81cdbf5117f2d16e5a871849a7875b1746baf42a 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -81,19 +82,56 @@ struct ShapeTreeNode { // Like the Shape data structure, this is a tree and tuple elements cannot be // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T // object. +// +// Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes +// it's helpful not to copy a Shape just to make a ShapeTree. In these cases, +// you can pass a Shape* instead of a Shape& to the ShapeTree constructor. It's +// then up to you to ensure that the pointed-to Shape doesn't die or mutate +// before its ShapeTree goes away. template class ShapeTree { public: // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} + // Create ShapeTree with the given shape, and default-constructed T values for // all nodes. - explicit ShapeTree(const Shape& shape); + // + // The version that takes a pointer may be cheaper because it doesn't require + // any Shape copies, but then it's up to you to ensure that the pointer stays + // alive longer than this ShapeTree. + explicit ShapeTree(Shape shape); + explicit ShapeTree(const Shape* shape); + // Create ShapeTree with the given shape, and init_value for all nodes. - ShapeTree(const Shape& shape, const T& init_value); + ShapeTree(Shape shape, const T& init_value); + ShapeTree(const Shape* shape, const T& init_value); + + ShapeTree(const ShapeTree& other) + : root_(other.root_), shape_storage_(other.shape_storage_) { + // Fix up internal pointer if necessary. 
+ if (shape_storage_) { + CHECK_EQ(other.shape_, &*other.shape_storage_); + shape_ = &*shape_storage_; + } else { + shape_ = other.shape_; + } + } - ShapeTree(const ShapeTree& other) = default; - ShapeTree& operator=(const ShapeTree& other) = default; + ShapeTree& operator=(const ShapeTree& other) { + root_ = other.root_; + shape_storage_ = other.shape_storage_; + + // Fix up internal pointer if necessary. + if (shape_storage_) { + CHECK_EQ(other.shape_, &*other.shape_storage_); + shape_ = &*shape_storage_; + } else { + shape_ = other.shape_; + } + + return *this; + } // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). @@ -101,7 +139,7 @@ class ShapeTree { T* mutable_element(const ShapeIndex& index); // Return the shape represented with this ShapeTree. - const Shape& shape() const { return shape_; } + const Shape& shape() const { return *shape_; } // Returns true if the node at the given index is a leaf node (an array // shape). @@ -112,27 +150,27 @@ class ShapeTree { // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // + // Fn : A callable of type void(const ShapeIndex& index, const T& data) + // (or compatible). // index : the index of the element in the shape. See ShapeUtil::GetSubshape // for definition of index. // data : The data value at this elemnt. - using VisitorFunction = - std::function; - void ForEachElement(const VisitorFunction& func) const; - - using MutableVisitorFunction = - std::function; - void ForEachMutableElement(const MutableVisitorFunction& func); + template + void ForEachElement(const Fn& func) const; - // Variants of ForEach(Mutable)Element which propagate a Status value from the - // visitor. 
- using StatusVisitorFunction = - std::function; - Status ForEachElementWithStatus(const StatusVisitorFunction& func) const; + // Like ForEachElement, but the callable has type + // + // void (const ShapeIndex& index, T* data). + // + template + void ForEachMutableElement(const Fn& func); - using MutableStatusVisitorFunction = - std::function; - Status ForEachMutableElementWithStatus( - const MutableStatusVisitorFunction& func); + // Like ForEach(Mutable)Element, but the callable returns a Status instead of + // void. The first non-OK return value is returned by the ForEach* function. + template + Status ForEachElementWithStatus(const Fn& func) const; + template + Status ForEachMutableElementWithStatus(const Fn& func); // Copy the subtree of values from 'other' rooted at ShapeIndex // 'source_base_index' into the subtree of value in this ShapeTree rooted at @@ -161,10 +199,12 @@ class ShapeTree { // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). - static Status ForEachHelper(const StatusVisitorFunction& func, - const Node& node, ShapeIndex* index); - static Status ForEachMutableHelper(const MutableStatusVisitorFunction& func, - Node* node, ShapeIndex* index); + template + static Status ForEachHelper(const Fn& func, const Node& node, + ShapeIndex* index); + template + static Status ForEachMutableHelper(const Fn& func, Node* node, + ShapeIndex* index); // Return the tree node at the given index. Node* Lookup(const ShapeIndex& index); @@ -173,8 +213,13 @@ class ShapeTree { // The root node, which contains all other nodes. Node root_; - // The XLA shape mirrored in this ShapeTree. - Shape shape_; + // If we own our Shape, this field contains it, and shape_ is a pointer into + // here. Otherwise if we don't own our shape, this is nullopt. + tensorflow::gtl::optional shape_storage_; + + // The XLA shape mirrored in this ShapeTree. 
This is either a pointer into + // shape_storage_ or the Shape pointer passed to our constructor. + const Shape* shape_; }; template @@ -200,20 +245,34 @@ void ShapeTree::InitChildren(const Shape& shape, Node* node) { } template -ShapeTree::ShapeTree(const Shape& shape) : root_(), shape_(shape) { +ShapeTree::ShapeTree(Shape shape) + : root_(), shape_storage_(std::move(shape)), shape_(&*shape_storage_) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. - LayoutUtil::ClearLayout(&shape_); - InitChildren(shape_, &root_); + LayoutUtil::ClearLayout(&*shape_storage_); + InitChildren(*shape_, &root_); } template -ShapeTree::ShapeTree(const Shape& shape, const T& init_value) - : root_(init_value), shape_(shape) { +ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { + InitChildren(*shape_, &root_); +} + +template +ShapeTree::ShapeTree(Shape shape, const T& init_value) + : root_(init_value), + shape_storage_(std::move(shape)), + shape_(&*shape_storage_) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. 
- LayoutUtil::ClearLayout(&shape_); - InitChildren(shape_, init_value, &root_); + LayoutUtil::ClearLayout(&*shape_storage_); + InitChildren(*shape_, init_value, &root_); +} + +template +ShapeTree::ShapeTree(const Shape* shape, const T& init_value) + : root_(init_value), shape_(shape) { + InitChildren(*shape_, init_value, &root_); } template @@ -245,8 +304,9 @@ const internal::ShapeTreeNode* ShapeTree::Lookup( /* static */ template -Status ShapeTree::ForEachHelper(const StatusVisitorFunction& func, - const Node& node, ShapeIndex* index) { +template +Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, + ShapeIndex* index) { TF_RETURN_IF_ERROR(func(*index, node.data)); for (int64 i = 0; i < node.children.size(); ++i) { index->push_back(i); @@ -258,8 +318,9 @@ Status ShapeTree::ForEachHelper(const StatusVisitorFunction& func, /* static */ template -Status ShapeTree::ForEachMutableHelper( - const MutableStatusVisitorFunction& func, Node* node, ShapeIndex* index) { +template +Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, + ShapeIndex* index) { TF_RETURN_IF_ERROR(func(*index, &node->data)); for (int64 i = 0; i < node->children.size(); ++i) { index->push_back(i); @@ -271,21 +332,22 @@ Status ShapeTree::ForEachMutableHelper( } template -Status ShapeTree::ForEachElementWithStatus( - const StatusVisitorFunction& func) const { +template +Status ShapeTree::ForEachElementWithStatus(const Fn& func) const { ShapeIndex index; return ForEachHelper(func, root_, &index); } template -Status ShapeTree::ForEachMutableElementWithStatus( - const MutableStatusVisitorFunction& func) { +template +Status ShapeTree::ForEachMutableElementWithStatus(const Fn& func) { ShapeIndex index; return ForEachMutableHelper(func, &root_, &index); } template -void ShapeTree::ForEachElement(const VisitorFunction& func) const { +template +void ShapeTree::ForEachElement(const Fn& func) const { ShapeIndex index; return ForEachHelper( [&func](const ShapeIndex& index, const T& 
data) { @@ -297,7 +359,8 @@ void ShapeTree::ForEachElement(const VisitorFunction& func) const { } template -void ShapeTree::ForEachMutableElement(const MutableVisitorFunction& func) { +template +void ShapeTree::ForEachMutableElement(const Fn& func) { ShapeIndex index; return ForEachMutableHelper( [&func](const ShapeIndex& index, T* data) { diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index afc3a2b2a34777780ec66d2325011390879fe693..3a5db1b3a651e2d353741c6bf4f6962da4e54ba1 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -365,5 +365,31 @@ TEST_F(ShapeTreeTest, OperatorEquals) { } } +TEST_F(ShapeTreeTest, ConstructWithPointerToShape) { + // Construct a ShapeTree using a pointer to a shape, rather than a reference + // to a shape. This constructor is an optimization to let us avoid + // constructing and destroying temporary shapes when we have many ShapeTrees. + ShapeTree t(&nested_tuple_shape_, 42); + int num_nodes = 0; + t.ForEachElement([&num_nodes](const ShapeIndex& /*index*/, int data) { + EXPECT_EQ(42, data); + ++num_nodes; + }); + EXPECT_EQ(10, num_nodes); +} + +TEST_F(ShapeTreeTest, CopyWithPointerToShape) { + ShapeTree source(&nested_tuple_shape_, 0); + ShapeTree dest(source); + EXPECT_EQ(&dest.shape(), &nested_tuple_shape_); +} + +TEST_F(ShapeTreeTest, CopyAssignWithPointerToShape) { + ShapeTree source(&nested_tuple_shape_, 0); + ShapeTree dest; + dest = source; + EXPECT_EQ(&dest.shape(), &nested_tuple_shape_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ee49a9ae5f5ff442284f2c4bd620425f815fb08d..057905a4311edc246eeea55019821e834605ae78 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -105,6 +105,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return equal; } +/* static */ 
int64 ShapeUtil::Rank(const Shape& shape) { + CHECK(!ShapeUtil::IsTuple(shape)) << "Tuples do not have a rank"; + return shape.dimensions_size(); +} + /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -165,6 +170,17 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims); } + +/* static */ Shape ShapeUtil::ShapeWithoutPadding(const Shape& shape) { + Shape result = shape; + ForEachMutableSubshape(&result, [](Shape* subshape, const ShapeIndex& index) { + auto layout = subshape->mutable_layout(); + layout->clear_padding_value(); + layout->clear_padded_dimensions(); + }); + return result; +} + /* static */ void ShapeUtil::PopulateShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, Shape* shape) { @@ -270,7 +286,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsEmptyTuple(shape) || HasZeroElements(shape); + return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape); } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { @@ -323,6 +339,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { + CHECK(!IsTuple(shape)); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -534,11 +551,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == TUPLE) { - // Tuple shape. 
- if (Rank(shape) != 0) { - return InvalidArgument("tuples must be rank-0; got rank %lld", - Rank(shape)); - } if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 853be6b4cb81881f3f03dbb119dee533aa27634f..fa34bfc951d58d252b4381e10a01b39698eb9015 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -93,6 +93,7 @@ class ShapeUtil { public: // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. + // Precondition: !IsTuple(shape) static int64 ElementsIn(const Shape& shape); // Returns true if 'shape' has zero elements. @@ -144,7 +145,8 @@ class ShapeUtil { static bool Equal(const Shape& lhs, const Shape& rhs); // Returns the rank (number of dimensions) of the given shape. - static int64 Rank(const Shape& shape) { return shape.dimensions_size(); } + // Precondition: !IsTuple(shape) + static int64 Rank(const Shape& shape); // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just @@ -220,6 +222,9 @@ class ShapeUtil { // elements with a different shape. static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape); + // Returns a new shape that has all padding values cleared. + static Shape ShapeWithoutPadding(const Shape& shape); + // As MakeShape, but the object to write to is passed in. 
static void PopulateShape(PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 13dd1a30b60a64171425f2a7d872da9bb2ca5380..5298af788ea479a92cc2554ed6870032dfd18bd7 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -94,11 +94,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", @@ -116,8 +116,12 @@ cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", ], ) @@ -139,6 +143,7 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -151,7 +156,6 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -171,6 +175,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", @@ -196,12 +201,14 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", @@ -213,12 +220,13 @@ xla_test( srcs = ["bad_rng_shape_validation_test.cc"], deps = [ "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", @@ -233,12 +241,12 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", 
"//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -255,7 +263,6 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -268,6 +275,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", @@ -275,7 +283,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -291,7 +298,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", @@ -307,6 +313,7 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", @@ -315,7 +322,6 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -339,7 +345,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -356,7 +361,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -371,7 +375,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -388,7 +392,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -409,7 +413,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -422,12 +426,13 @@ xla_test( srcs = ["deallocation_test.cc"], deps = [ "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -441,13 +446,14 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", 
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -471,9 +477,28 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "reduce_precision_test", + srcs = ["reduce_precision_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -490,8 +515,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", 
"//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -503,14 +527,10 @@ xla_test( ) # Tests the dot operation in some cases that can be performed via a -# runtime call on some backends - e.g. a runtime call to to Eigen. +# runtime call on some backends - e.g. a runtime call to Eigen. xla_test( name = "dot_operation_runtime_test", srcs = ["dot_operation_test.cc"], - backend_args = { - "cpu": ["--xla_cpu_use_eigen"], - "cpu_parallel": ["--xla_cpu_use_eigen"], - }, deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -518,8 +538,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -536,11 +555,9 @@ xla_test( srcs = ["dot_operation_test.cc"], backend_args = { "cpu": [ - "--xla_cpu_use_eigen", "--xla_cpu_multi_thread_eigen=false", ], "cpu_parallel": [ - "--xla_cpu_use_eigen", "--xla_cpu_multi_thread_eigen=false", ], }, @@ -551,8 +568,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ 
-568,11 +584,9 @@ xla_test( srcs = ["dot_operation_test.cc"], backend_args = { "cpu": [ - "--xla_cpu_use_eigen", "--xla_default_layout=major2minor", ], "cpu_parallel": [ - "--xla_cpu_use_eigen", "--xla_default_layout=major2minor", ], }, @@ -583,8 +597,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -605,7 +618,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -624,7 +637,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -650,7 +663,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", 
"//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -677,7 +690,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -694,12 +707,13 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -710,19 +724,25 @@ xla_test( xla_test( name = "batch_normalization_test", srcs = ["batch_normalization_test.cc"], + shard_count = 40, deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", 
"//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -738,7 +758,7 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -754,7 +774,7 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -768,11 +788,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ 
-799,7 +821,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -816,7 +838,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -842,7 +864,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -850,11 +872,14 @@ xla_test( ], ) -xla_test( - name = "reduce_window_test", - timeout = "long", +# Note that the backend-specific macros (e.g. DISABLED_ON_CPU) would not work here, because the +# library does not get recompiled for each backend. 
+cc_library( + name = "reduce_window_test_library", + testonly = 1, srcs = ["reduce_window_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", @@ -865,7 +890,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -873,6 +898,15 @@ xla_test( ], ) +xla_test( + name = "reduce_window_test", + timeout = "long", + srcs = [], + deps = [ + ":reduce_window_test_library", + ], +) + xla_test( name = "select_and_scatter_test", timeout = "long", @@ -889,7 +923,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -906,7 +940,7 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -921,10 +955,11 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -941,11 +976,12 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -958,7 +994,7 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -976,8 +1012,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", ], @@ -995,7 +1030,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - 
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1009,7 +1044,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1022,7 +1057,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1044,7 +1079,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1058,11 +1093,12 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1081,13 +1117,14 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1103,7 +1140,7 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1125,7 +1162,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1142,11 +1179,12 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1161,7 +1199,6 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1184,7 +1221,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1199,7 +1236,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1215,13 +1252,14 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -1240,7 +1278,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1262,7 +1300,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1279,7 +1317,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1298,7 +1336,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", 
"//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1315,10 +1353,41 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "multioutput_fusion_test", + srcs = ["multioutput_fusion_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + 
"//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1333,6 +1402,7 @@ cc_test( linkstatic = 1, deps = [ "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -1347,9 +1417,9 @@ cc_test( ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:local_service", - "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) @@ -1365,7 +1435,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1381,7 +1451,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1407,7 +1477,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1415,6 +1485,15 @@ xla_test( ], ) +xla_test( + name = "deep_graph_test", + srcs = ["deep_graph_test.cc"], + deps = [ + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + ], +) + cc_test( name = "literal_test_util_test", srcs = ["literal_test_util_test.cc"], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index c07f2745fe9e67898148bf0026ac32534eac506c..fa6c97bcb8767a3cd2bd7b976c1dd6165efe023a 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -26,9 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -45,7 +43,7 @@ namespace { class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: - ErrorSpec error_spec_{0.0001}; + ErrorSpec error_spec_{0.0001, 0.0001}; }; class ArrayElementwiseOpTestParamCount @@ -158,13 +156,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); + std::unique_ptr a_literal = Literal::CreateR1({a_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a_constant = builder.ConstantR1(a_values); auto a_param = builder.Parameter(0, 
a_literal->shape(), "a_param"); - std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); + std::unique_ptr b_literal = Literal::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param"); @@ -804,7 +802,7 @@ TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + std::unique_ptr param_literal = Literal::CreateR1(values); std::unique_ptr param_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -826,6 +824,244 @@ TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { ComputeAndCompareR1(&b, expected, {param_data.get()}, error_spec_); } +TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; + std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + b.Pow(b.Exp(param0), param1); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = std::pow(std::exp(values0[i]), values1[i]); + } + + ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; + std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr 
literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + b.Log(b.Pow(param0, param1)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = std::log(std::pow(values0[i], values1[i])); + } + + ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; + std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + b.Mul(b.Exp(param0), b.Exp(param1)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = std::exp(values0[i]) * std::exp(values1[i]); + } + + ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; + std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + std::unique_ptr literal1 = Literal::CreateR1(values1); + 
std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + b.Div(param0, b.Exp(param1)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = values0[i] / std::exp(values1[i]); + } + + ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; + std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr data2 = + client_->TransferToServer(*literal2).ConsumeValueOrDie(); + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + auto param2 = b.Parameter(2, literal2->shape(), "param2"); + b.Div(b.Div(param0, param1), param2); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = (values0[i] / values1[i]) / values2[i]; + } + + ComputeAndCompareR1( + &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; + std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; + + std::unique_ptr literal0 = 
Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr data2 = + client_->TransferToServer(*literal2).ConsumeValueOrDie(); + + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + auto param2 = b.Parameter(2, literal2->shape(), "param2"); + b.Div(param0, b.Div(param1, param2)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = values0[i] / (values1[i] / values2[i]); + } + + ComputeAndCompareR1( + &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; + std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; + std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr data2 = + client_->TransferToServer(*literal2).ConsumeValueOrDie(); + + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + auto param2 = b.Parameter(2, literal2->shape(), "param2"); + b.Div(param0, b.Pow(param1, param2)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = values0[i] / std::pow(values1[i], values2[i]); + } + + 
ComputeAndCompareR1( + &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Div4F32) { + ComputationBuilder b(client_, TestName()); + + std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; + std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; + std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; + + std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr data0 = + client_->TransferToServer(*literal0).ConsumeValueOrDie(); + + std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr data1 = + client_->TransferToServer(*literal1).ConsumeValueOrDie(); + + std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr data2 = + client_->TransferToServer(*literal2).ConsumeValueOrDie(); + + std::unique_ptr literal3 = Literal::CreateR1(values3); + std::unique_ptr data3 = + client_->TransferToServer(*literal3).ConsumeValueOrDie(); + + auto param0 = b.Parameter(0, literal0->shape(), "param0"); + auto param1 = b.Parameter(1, literal1->shape(), "param1"); + auto param2 = b.Parameter(2, literal2->shape(), "param2"); + auto param3 = b.Parameter(3, literal3->shape(), "param2"); + b.Div(b.Div(param0, param1), b.Div(param2, param3)); + + std::vector expected(values0.size()); + for (int64 i = 0; i < values0.size(); ++i) { + expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]); + } + + ComputeAndCompareR1( + &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()}, + error_spec_); +} + TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); @@ -1241,12 +1477,12 @@ TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 
5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); + Literal::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -1263,12 +1499,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); + Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); + Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -1285,7 +1521,7 @@ TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -1297,6 +1533,15 @@ TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { {param0_data.get()}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); + auto result = builder.Cos(a); + + ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, + error_spec_); +} + TEST_F(ArrayElementwiseOpTest, TanhF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); @@ -1447,9 +1692,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); auto 
result = builder.Tuple({cmp_dim_0, cmp_dim_1}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), - LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); + auto expected = Literal::MakeTuple( + {Literal::CreateR2({{true, true}, {true, false}}).get(), + Literal::CreateR2({{true, false}, {false, false}}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -1802,7 +2047,7 @@ TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); ComputationBuilder builder(client_, TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR4FromArray4D(r4); + std::unique_ptr a_literal = Literal::CreateR4FromArray4D(r4); *a_literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); auto a = builder.ConstantLiteral(*a_literal); @@ -1838,8 +2083,8 @@ TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { // broadcast. TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { ComputationBuilder builder(client_, TestName()); - auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); - auto y_literal = LiteralUtil::CreateR1({4, 5}); + auto x_literal = Literal::CreateR1({1, 2, 3}); + auto y_literal = Literal::CreateR1({4, 5}); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); @@ -1862,8 +2107,6 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); - xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc 
b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index a1ca1de584f8be808d19a43680f7c093d4f94def..67dbc913b42c89bf5a8fb5b91da13a29e5e248f5 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -76,7 +75,6 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index ea58491038c1dfcc8069b3c14833ade554be0d8a..02be0b5ab83c23fda36c5ccc65a598fc8e4a1600 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -70,7 +69,6 @@ TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 6a47f1b718a1734de731ec50d7094ac529eca9df..98bf2063c20ee2ca6c6dc734dcaddb600d096737 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -23,13 +23,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -48,7 +57,7 @@ class BatchNormalizationTest : public ClientLibraryTestBase { {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = *LiteralUtil::CreateR4FromArray4D(input_array_); + input_literal_ = *Literal::CreateR4FromArray4D(input_array_); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -190,13 +199,258 @@ TEST_F(BatchNormalizationTest, SpecComparisonForward) { ComputeAndCompareR4(&builder, expected, {}, error_spec_); } +struct BatchNormTestParam { + std::vector bounds; + int64 feature_index; + float random_value_mean; + float random_value_var; +}; + +// Tests to test the fused operation of BatchNorm. 
+class BatchNormTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { +}; + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. +XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) { + float epsilon = 0.001; + ComputationBuilder builder(client_, TestName()); + const std::vector& bounds = GetParam().bounds; + Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); + input_array.FillRandom(GetParam().random_value_var, + GetParam().random_value_mean); + + const int64 feature_index = GetParam().feature_index; + const int64 num_elements_per_feature = + Product(bounds) / bounds[feature_index]; + const int64 feature_bound = bounds[feature_index]; + std::vector offset(feature_bound, 1); + std::vector scale(feature_bound, 2); + + auto input_squared = + ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; }); + std::vector reduce_dims; + for (int64 i = 0; i < bounds.size(); ++i) { + if (i != feature_index) { + reduce_dims.push_back(i); + } + } + + auto sum = + ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims, + [](float a, float b) { return a + b; }); + + auto sum_squared = + ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims, + [](float a, float b) { return a + b; }); + + std::vector mean(feature_bound); + + for (int64 i = 0; i < feature_bound; ++i) { + mean[i] = sum[i] / num_elements_per_feature; + } + + std::vector mean_square(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + mean_square[i] = mean[i] * mean[i]; + } + + std::vector square_mean(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + square_mean[i] = sum_squared[i] / num_elements_per_feature; + } + + std::vector var(feature_bound); + for (int64 i = 0; i < feature_bound; ++i) { + var[i] = square_mean[i] - mean_square[i]; + } + + Array4D mean_4D = + *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); + auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, 
feature_index); + auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); + auto offset_4D = + *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index); + + auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean_4D, var_4D, + scale_4D, offset_4D, epsilon); + + auto expected_normalized = Literal::CreateR4FromArray4D(normalized); + + auto offset_literal = Literal::CreateR1(offset); + auto scale_literal = Literal::CreateR1(scale); + auto input_literal = Literal::CreateR4FromArray4D(input_array); + + auto input_activations = + builder.Parameter(0, input_literal->shape(), "input"); + auto scale_activations = + builder.Parameter(1, scale_literal->shape(), "offset"); + auto offset_activations = + builder.Parameter(2, offset_literal->shape(), "scale"); + + auto expected = *Literal::MakeTuple({expected_normalized.get(), + Literal::CreateR1(mean).get(), + Literal::CreateR1(var).get()}); + + std::unique_ptr input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + std::unique_ptr scale_data = + client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + std::unique_ptr offset_data = + client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + + builder.BatchNormTraining(input_activations, scale_activations, + offset_activations, epsilon, feature_index); + + ComputeAndCompareTuple( + &builder, expected, + {input_data.get(), scale_data.get(), offset_data.get()}, + ErrorSpec(0.01, 1)); +} + +INSTANTIATE_TEST_CASE_P( + BatchNormTest_Instantiation, BatchNormTest, + ::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f}, + BatchNormTestParam{{2, 2, 2, 2}, 3, 300.f, 400.0f}, + + BatchNormTestParam{{1, 10, 1, 1}, 0, 10.1f, 20.1f}, + BatchNormTestParam{{10, 10, 10, 10}, 1, 3.14f, 314.15f}, + BatchNormTestParam{{10, 10, 10, 10}, 2, 666.6f, 777.7f}, + BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f}, + BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f}, + BatchNormTestParam{{1, 1, 10, 1}, 3, 
888.8f, 9.9f}, + + BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000}, + BatchNormTestParam{{24, 129, 1, 2}, 3, 10000, 10000}, + + // Feature on low dimension to trigger relayout, test + // internal logical to physical dimension calculation + // is correct after relayout. + BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100})); + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. +XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) { + const int kFeatureIndex = 3; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); + + auto scale = builder.ConstantR1({2.0f, 3.0f}); + + auto offset = builder.ConstantR1({1.0f, 2.0f}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) + .get(), + Literal::CreateR1({4, 5}).get(), + Literal::CreateR1({5, 5}).get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); +} + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. 
+XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); + + auto scale = builder.ConstantR1({2.0f, 3.0f}); + + auto offset = builder.ConstantR1({1.0f, 2.0f}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) + .get(), + Literal::CreateR1({4, 5}).get(), + Literal::CreateR1({5, 5}).get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); +} + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. +XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) { + // Use 0 dimension as feature, tests layout analyzer. + const int kFeatureIndex = 0; + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle h0; + auto operand = CreateR3Parameter(Array3D(260, 2, 2, 1.0f), + /*parameter_number=*/0, "operand", + &builder, &h0); + ComputationDataHandle h1; + auto scale = + CreateR1Parameter(std::vector(260, 1.0f), + /*parameter_number=*/1, "scale", &builder, &h1); + ComputationDataHandle h2; + auto offset = + CreateR1Parameter(std::vector(260, 1.0f), + /*parameter_number=*/2, "offset", &builder, &h2); + + auto tuple = builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) + .get(), + Literal::CreateR1(std::vector(260, 1.0f)).get(), + Literal::CreateR1(std::vector(260, 0.0f)).get()}); + + ComputeAndCompareTuple(&builder, expected, + {operand.get(), scale.get(), offset.get()}, + ErrorSpec(0.1)); +} + +// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. 
+XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) { + // Test the correctness of choosing a large epsilon value. + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle h0; + auto operand = CreateR3Parameter({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, + /*parameter_number=*/0, "operand", + &builder, &h0); + ComputationDataHandle h1; + auto scale = + CreateR1Parameter(std::vector(1, 1.0f), + /*parameter_number=*/1, "scale", &builder, &h1); + ComputationDataHandle h2; + auto offset = + CreateR1Parameter(std::vector(1, 0.0f), + /*parameter_number=*/2, "offset", &builder, &h2); + + // var = 125, mean = 15, epsilon = -100 + auto tuple = builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) + .get(), + Literal::CreateR1(std::vector(1, 15.0f)).get(), + Literal::CreateR1(std::vector(1, 125.0f)).get()}); + + ComputeAndCompareTuple(&builder, expected, + {operand.get(), scale.get(), offset.get()}, + ErrorSpec(0.1)); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 5e3b70702dd482e6b278386d70fef60b1bacb926..e6b853c2e4e4a08174012c1eb8be3739a2c9dba9 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -143,7 +142,6 @@ TEST_F(BinopScalingTest, R4PlusR0S32) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 25fe04a930e3783ff6024a0bb3bddc430c4fafdd..2a57835ca93cd2b367fe0402aee1f986ae2d4ff3 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -21,9 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -63,9 +61,8 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array3D* r3_array, float start, float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = - LiteralUtil::Relayout(*LiteralUtil::CreateR3FromArray3D(*r3_array), - LayoutUtil::MakeLayout(minor_to_major)); + auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( + LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = client_->TransferToServer(*r3_data).ConsumeValueOrDie(); return r3_global_data; @@ -77,9 +74,8 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array2D* r2_array, float start, float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = - LiteralUtil::Relayout(*LiteralUtil::CreateR2FromArray2D(*r2_array), - LayoutUtil::MakeLayout(minor_to_major)); + auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( + LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = client_->TransferToServer(*r2_data).ConsumeValueOrDie(); return r2_global_data; @@ -217,13 +213,13 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { ComputationBuilder b(client_, TestName()); b.Add(b.ConstantR2({{1.0, 5.0}}), - b.ConstantLiteral(*LiteralUtil::CreateR3( + b.ConstantLiteral(*Literal::CreateR3( {{{2.0}, {3.0}, {4.0}}, 
{{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); auto expected = - LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, - {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); + Literal::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, + {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -292,7 +288,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } } } - auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + auto expected = Literal::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r3_implicit_global_data.get(), r3_global_data.get()}, @@ -317,7 +313,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { b.Add(r3h, r1h); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + Literal::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); @@ -325,81 +321,79 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + Literal::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}, {2}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + 
*Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + Literal::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { ComputationBuilder b(client_, TestName()); - auto r1 = - b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + Literal::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { ComputationBuilder b(client_, TestName()); - auto r1 = - b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + Literal::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r1 = + b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - 
LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + Literal::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + Literal::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -541,7 +535,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { *v = ApplyOpToFloats(spec.op2, tmp, v3); }); - auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + auto expected = Literal::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r2_implicit_global_data1.get(), r2_global_data.get(), @@ -555,22 +549,22 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}})); - auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}})); + auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); - auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); + auto expected = Literal::CreateR2({{2, 4}, {4, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1}, {2}})); - auto r2 = 
b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1}, {2}})); + auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); - auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); + auto expected = Literal::CreateR2({{2, 3}, {5, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -579,11 +573,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1, {0}); - auto expected = LiteralUtil::CreateR3( - {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); + auto expected = + Literal::CreateR3({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -592,11 +586,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r1, r3, {1}); - auto expected = LiteralUtil::CreateR3( - {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); + auto expected = + Literal::CreateR3({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -605,11 +599,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r1, r3, {2}); - auto expected = LiteralUtil::CreateR3( - {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); + auto expected = + Literal::CreateR3({{{11, 22}, {13, 24}}, 
{{15, 26}, {17, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -620,7 +614,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = b.ConstantR1({100, 200}); auto r1_2 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = b.Add(r1_0, r3, {0}); r3 = b.Add(r3, r1_1, {1}); @@ -628,7 +622,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { } r3 = b.Mul(r3, b.ConstantR0(-2)); - auto expected = LiteralUtil::CreateR3( + auto expected = Literal::CreateR3( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); @@ -649,7 +643,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { } r3 = b.Mul(r3, b.ConstantR0(-1)); - auto expected = LiteralUtil::CreateR3( + auto expected = Literal::CreateR3( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); @@ -662,7 +656,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { ComputationBuilder b(client_, TestName()); b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), - b.ConstantLiteral(*LiteralUtil::CreateR3( + b.ConstantLiteral(*Literal::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); @@ -704,8 +698,6 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); - xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git 
a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 96a329a9bd8296a11a3e22e8dea31d71dd973d76..dc1443f5363aab1e6166984a3f2f3fccefad908e 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -39,7 +38,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { // Test degenerate case of broadcasting a scalar into a scalar. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {}), input, {})); @@ -48,14 +47,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0(42.0), *result, + LiteralTestUtil::ExpectNear(*Literal::CreateR0(42.0), *result, error_spec_); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {})); @@ -65,14 +64,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + 
*Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, error_spec_); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple // to enable testing of the results. @@ -88,18 +87,18 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), result->tuple_literals(0), error_spec_); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), result->tuple_literals(1), error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); @@ -109,7 +108,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, error_spec_); } @@ -118,7 +117,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { // the dimensions, ie transpose. 
auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); @@ -128,14 +127,14 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + *Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); @@ -145,15 +144,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0, 2.0}))); + HloInstruction::CreateConstant(Literal::CreateR1({1.0, 2.0}))); // Broadcast vector in dimension 1. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -168,8 +167,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -178,7 +177,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { int64 r1_size = input_data.size(); std::iota(input_data.begin(), input_data.end(), 0.0f); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1(input_data))); + HloInstruction::CreateConstant(Literal::CreateR1(input_data))); // Broadcast vector in dimension 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -198,8 +197,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -209,7 +208,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { std::vector r1_array(64, 42.0); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1(r1_array))); + HloInstruction::CreateConstant(Literal::CreateR1(r1_array))); // Broadcast vector in dimension 1. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -220,14 +219,14 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR4FromArray4D(r4_array), - *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, + error_spec_); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); @@ -240,15 +239,15 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { auto builder = HloComputation::Builder(TestName()); Array2D to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}}); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(to_broadcast))); + Literal::CreateR2FromArray2D(to_broadcast))); // Broadcast vector in dimensions 2 and 3. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -262,8 +261,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -282,7 +281,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { } } auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR3FromArray3D(input_vals))); + Literal::CreateR3FromArray3D(input_vals))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -293,8 +292,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } } // namespace @@ -302,7 +301,6 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 1f61743451a79a062205708d9ba6014f8a8591e9..a297132dd3ffdd597fac64483e161bc2b9602b41 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -1,12 +1,27 @@ """Build rules for XLA testing.""" load("@local_config_cuda//cuda:build_defs.bzl", 
"cuda_is_configured") +load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") -def all_backends(): +all_backends = ["cpu", "cpu_parallel", "gpu"] + plugins.keys() + +def filter_backends(backends): + """Removes "gpu" from a backend list if CUDA is not enabled. + + This allows us to simply hardcode lists including "gpu" here and in the + BUILD file, without causing failures when CUDA isn't enabled.' + + Args: + backends: A list of backends to filter. + + Returns: + The filtered list of backends. + """ if cuda_is_configured(): - return ["cpu", "cpu_parallel", "gpu"] + return backends else: - return ["cpu", "cpu_parallel"] + return [backend for backend in backends if backend != "gpu"] + def xla_test(name, srcs, @@ -81,7 +96,7 @@ def xla_test(name, """ test_names = [] if not backends: - backends = all_backends() + backends = all_backends native.cc_library( name="%s_lib" % name, @@ -91,7 +106,7 @@ def xla_test(name, deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], ) - for backend in backends: + for backend in filter_backends(backends): test_name = "%s_%s" % (name, backend) this_backend_tags = ["xla_%s" % backend] this_backend_copts = [] @@ -107,6 +122,11 @@ def xla_test(name, backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] this_backend_tags += ["requires-gpu-sm35"] + elif backend in plugins: + backend_deps = plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + this_backend_tags += plugins[backend]["tags"] + this_backend_args += plugins[backend]["args"] else: fail("Unknown backend %s" % backend) @@ -127,16 +147,16 @@ def xla_test(name, def generate_backend_suites(backends=[]): if not backends: - backends = all_backends() - for backend in backends: + backends = all_backends + for backend in filter_backends(backends): native.test_suite(name="%s_tests" % backend, tags = ["xla_%s" % backend]) def generate_backend_test_macros(backends=[]): if 
not backends: - backends = all_backends() - for backend in backends: + backends = all_backends + for backend in filter_backends(backends): native.cc_library( name="test_macros_%s" % backend, testonly = True, diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 55701c62db22f0fff6f4fdeabf0c72d600239969..086199fda1445c966917cff6849373e4474d16f7 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -78,7 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR0F32IdentityComputation(); - auto constant = builder.ConstantLiteral(*LiteralUtil::CreateR0(42.0)); + auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); builder.Call(callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR1S0F32AdditionComputation(); - auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1({})); - auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1({})); + auto x = builder.ConstantLiteral(*Literal::CreateR1({})); + auto y = builder.ConstantLiteral(*Literal::CreateR1({})); builder.Call(callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -97,8 +96,8 @@ 
XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR1S2F32AdditionComputation(); - auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1({1.0f, 2.0f})); - auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1({2.0f, 3.0f})); + auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); + auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); builder.Call(callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -107,8 +106,8 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR0F32TupleComputation(); - auto elem = LiteralUtil::CreateR0(42.0); - auto tuple = LiteralUtil::MakeTuple({elem.get()}); + auto elem = Literal::CreateR0(42.0); + auto tuple = Literal::MakeTuple({elem.get()}); builder.Call(callee, {builder.ConstantLiteral(*elem)}); ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); @@ -120,7 +119,6 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 4825eaf19dc28fd78a5d91a3c1e722c3916f6c20..2f4ad22f5bf0573ba97e6d28a3a207480fcdae18 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -38,7 +37,7 @@ class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { ComputationBuilder builder(client_, "add_two_params"); - auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); + auto param_literal = Literal::CreateR1({1.1f, 2.2f}); auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param_literal->shape(), "param1"); @@ -55,18 +54,20 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { // The arity of the UserComputation is 2 arguments. Execution will succeed // with 2 arguments, but fail with a different number. 
- auto result_two_args = - client_->Execute(computation, {param0_data.get(), param1_data.get()}); + auto result_two_args = client_->Execute( + computation, {param0_data.get(), param1_data.get()}, &execution_options_); ASSERT_IS_OK(result_two_args.status()); - auto result_one_arg = client_->Execute(computation, {param0_data.get()}); + auto result_one_arg = + client_->Execute(computation, {param0_data.get()}, &execution_options_); ASSERT_FALSE(result_one_arg.ok()); ASSERT_EQ(result_one_arg.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(result_one_arg.status().error_message(), ContainsRegex("takes 2")); - auto result_zero_args = client_->Execute(computation, {}); + auto result_zero_args = + client_->Execute(computation, {}, &execution_options_); ASSERT_FALSE(result_zero_args.ok()); ASSERT_EQ(result_zero_args.status().code(), tensorflow::error::INVALID_ARGUMENT); @@ -85,35 +86,38 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_IS_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto f32_literal = LiteralUtil::CreateR0(1.1f); + auto f32_literal = Literal::CreateR0(1.1f); auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); - auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto f32_4_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); - auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); + auto u8_4_literal = Literal::CreateR1U8("hola"); auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); // Match - auto status = - client_->Execute(computation, {f32_data.get(), f32_4_data.get()}); + auto status = client_->Execute( + computation, {f32_data.get(), f32_4_data.get()}, &execution_options_); ASSERT_IS_OK(status.status()); // Shape mismatch in parameter 0 - status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}); + 
status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), ContainsRegex("expects parameter 0")); // Shape mismatch in parameter 1 (rank) - status = client_->Execute(computation, {f32_data.get(), f32_data.get()}); + status = client_->Execute(computation, {f32_data.get(), f32_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), ContainsRegex("expects parameter 1")); // Shape mismatch in parameter 1 (element type) - status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}); + status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), @@ -126,7 +130,6 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index b96bb8f846909589a52269f0d314dbfd0af2be09..1f7a0ec9e75be574f40c917d672a99410e9370b1 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -45,10 +45,8 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) { } // namespace ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) - : client_(GetOrCreateLocalClientOrDie(platform)) { - *(execution_options_.mutable_debug_options()) = - legacy_flags::GetDebugOptionsFromFlags(); - + : client_(GetOrCreateLocalClientOrDie(platform)), + execution_options_(CreateDefaultExecutionOptions()) { // Disabling constant_folding so that tests (usually written using Constants) // will exercise the intended code paths, instead of being constant folded. // @@ -71,13 +69,16 @@ StatusOr> ClientLibraryTestBase::Execute( return client_->Execute(computation, arguments, &execution_options_); } +StatusOr ClientLibraryTestBase::ExecuteAsync( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments) { + return client_->ExecuteAsync(computation, arguments, &execution_options_); +} + StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - ComputationBuilder* builder, + const Computation& computation, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout) { - // Build the computation, as a convenience. 
- TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = @@ -87,6 +88,15 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + // Build the computation, as a convenience. + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); +} + std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { @@ -113,14 +123,14 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return LiteralUtil::ToString(*result.ValueOrDie()); + return result.ValueOrDie()->ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); + std::unique_ptr expected_literal = Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -179,10 +189,10 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. 
- std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + std::unique_ptr expected_literal = Literal::CreateR1U8(expected); - VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal); - VLOG(1) << "actual: " << LiteralUtil::ToString(*actual); + VLOG(1) << "expected: " << expected_literal->ToString(); + VLOG(1) << "actual: " << actual->ToString(); EXPECT_EQ(expected, actual->u8s_string()); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index f9e1082ebb43ae112c417ff9a71ef8d38b5de900..f665bb4514a9ed4c76c60d814fdf8424d3a05b5b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -66,14 +66,23 @@ class ClientLibraryTestBase : public ::testing::Test { // TODO(b/25566808): Add helper that populates a literal from a testdata file. - // Convenience methods for building and running a computation from a builder. + // Convenience methods for building and running a computation with the member + // execution options. Modify execution_options_ in your test if you want to + // customize the options. StatusOr> Execute( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments); + StatusOr ExecuteAsync( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments); StatusOr> ExecuteAndTransfer( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); // Convenience OrDie variants of above methods. 
std::unique_ptr ExecuteOrDie( @@ -278,7 +287,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( ComputationBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); + Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -291,7 +300,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); + Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -301,7 +310,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -314,7 +323,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -324,7 +333,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( ComputationBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR2FromArray2D(expected); + Literal::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -337,7 +346,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr 
expected_literal = - LiteralUtil::CreateR2FromArray2D(expected); + Literal::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -347,7 +356,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( ComputationBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR3FromArray3D(expected); + Literal::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -360,7 +369,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR3FromArray3D(expected); + Literal::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -370,7 +379,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( ComputationBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR4FromArray4D(expected); + Literal::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -383,7 +392,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR4FromArray4D(expected); + Literal::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -392,7 +401,7 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR0(value); + std::unique_ptr literal 
= Literal::CreateR0(value); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -404,7 +413,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR1(values); + std::unique_ptr literal = Literal::CreateR1(values); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -416,7 +425,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); + std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -428,7 +437,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); + std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 1247804dae0effd387d5f276a3d64667bc69e18b..e84a6ce710229043c903c5e50daf33e2f93fa6da 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ 
b/tensorflow/compiler/xla/tests/client_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -47,7 +46,7 @@ TEST_F(ClientTest, ExecuteWithLayout) { auto computation = b.Build(); ASSERT_TRUE(computation.ok()) << computation.status(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, execute_layout); @@ -77,7 +76,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { auto computation = b.Build(); ASSERT_TRUE(computation.ok()) << computation.status(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; // Create a result shape with one element column major and the other row // major. 
*execution_options.mutable_shape_with_output_layout() = @@ -115,7 +114,6 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index cc3eb0e8d46a8ab13553cb78f58bfc48b16ee862..90767c4a17478d4e7edd6202a8629db5b115381d 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -32,6 +33,20 @@ limitations under the License. namespace xla { +std::unique_ptr CodegenTestBase::CreateNewModuleWithEmbeddedIr( + bool ftz) { + HloModuleConfig config; + auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_embed_ir_in_executable(true); + debug_options.set_xla_gpu_ftz(ftz); + // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
+ debug_options.add_xla_disable_hlo_passes("constant_folding"); + config.set_debug_options(debug_options); + + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} + void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr hlo_module, const string& pattern) { std::unique_ptr executable = @@ -43,8 +58,7 @@ void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr hlo_module, std::unique_ptr CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { return backend_->compiler() - ->Compile(std::move(hlo_module), test_hlo_dumper_, - backend_->default_stream_executor()) + ->Compile(std::move(hlo_module), backend_->default_stream_executor()) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h index 50c0453107095c5fdb6238c88a17b31728b6bf22..fa073cd91ee07462d7aaf40789e87dbc831da95e 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.h +++ b/tensorflow/compiler/xla/tests/codegen_test_base.h @@ -28,7 +28,11 @@ namespace xla { // Tests that verify IR emitted by the CPU/GPU backend is as expected. class CodegenTestBase : public HloTestBase { protected: - CodegenTestBase() {} + // Like HloTestBase::CreateNewModule, but also sets the "embed ir in + // executable" flag to true, since this is needed for codegen tests. + // The optional ftz flags configures whether these modules have their ftz + // option turned on. + std::unique_ptr CreateNewModuleWithEmbeddedIr(bool ftz = false); // Returns the embedded LLVM IR from the given executable. 
Codegen tests must // override this method, but execution tests do not have to because they do diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 18ea9714d1a8f5f5b127881f657e948d65003ab1..7038afc5b1f5dd388731ae82586fe24ac5476e8b 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -48,10 +47,10 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::unique_ptr result = client_ ->ExecuteAndTransfer(computation, arguments, - /*execution_options=*/nullptr, + /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0(expected_result), + LiteralTestUtil::ExpectNear(*Literal::CreateR0(expected_result), *result, error_spec_); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -62,14 +61,13 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::initializer_list> expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; - auto data_handle = - client_ - ->Execute(computation, arguments, /*execution_options=*/nullptr, - &execution_profile) - .ConsumeValueOrDie(); + auto data_handle = client_ + ->Execute(computation, arguments, + &execution_options_, &execution_profile) + .ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); - 
LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2(expected_result), + LiteralTestUtil::ExpectNear(*Literal::CreateR2(expected_result), *result, error_spec_); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -89,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) + client_->TransferToServer(*Literal::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) + client_->TransferToServer(*Literal::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) + client_->TransferToServer(*Literal::CreateR0(456.0f)) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -205,7 +203,6 @@ XLA_TEST_F(CompilationCacheTest, MutatedComputation) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 13c78fb16331340ae9b3586ac47a071230b73a83..4384c9b31495437db10744ea2b98b5b0b05b7ae4 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -86,7 +85,7 @@ class ComputeConstantTest : public ::testing::Test { ComputationBuilder* builder) { TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, builder)); - return LiteralUtil::Get(*literal, {}); + return literal->Get({}); } bool IsConstant(const ComputationDataHandle& operand, @@ -211,7 +210,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { auto computed = ComputeConstantLiteral(client, computation, &b); ASSERT_TRUE(computed.ok()) << computed.status(); std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); + Literal::CreateR1({4, 6}); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); } } @@ -225,7 +224,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { auto computed = ComputeConstantLiteral(client, computation, &b); ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); + std::unique_ptr expected_literal = Literal::CreateR0(5); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); } } @@ -291,7 +290,6 @@ TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 
a7034930bc9493dfc4931a77c05cf87e4d138173..c5d88ad6a08476731b5b09cb4ae16a3e76bbaf98 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -518,8 +517,8 @@ TEST_P(ConcatR2BinaryTest, DoIt) { // concat XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = LiteralUtil::CreateR0(2.f); - auto y_literal = LiteralUtil::CreateR0(3.f); + auto x_literal = Literal::CreateR0(2.f); + auto y_literal = Literal::CreateR0(3.f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); @@ -540,9 +539,9 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { // produces the correct result in rank 1. 
XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); - auto y_literal = LiteralUtil::CreateR0(1.5f); - auto z_literal = LiteralUtil::CreateR0(5.5f); + auto x_literal = Literal::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); + auto y_literal = Literal::CreateR0(1.5f); + auto z_literal = Literal::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); @@ -568,9 +567,9 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); Array3D x3d(3, 5, 7, 3.14f); - auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); - auto y_literal = LiteralUtil::CreateR0(1.5f); - auto z_literal = LiteralUtil::CreateR0(5.5f); + auto x_literal = Literal::CreateR3FromArray3D(x3d); + auto y_literal = Literal::CreateR0(1.5f); + auto z_literal = Literal::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); @@ -607,7 +606,6 @@ INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 1c065de8ba7663ac2e7b3dcd52298e6587d993f0..7c276c8c8d0c0e97b0dfba7a5d6a6165386e5261 100644 --- 
a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -113,7 +112,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { ComputationBuilder builder(client_, TestName()); auto constant = builder.ConstantLiteral( - *LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2))); + *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); } @@ -128,8 +127,8 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - auto constant = builder.ConstantLiteral( - *LiteralUtil::CreateR3FromArray3D(array3d)); + auto constant = + builder.ConstantLiteral(*Literal::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -143,7 +142,7 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - Literal input_literal = *LiteralUtil::CreateR4FromArray4D(input_array); + Literal input_literal = *Literal::CreateR4FromArray4D(input_array); { ComputationBuilder builder(client_, TestName()); @@ -161,9 +160,9 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. 
TEST_F(ConstantsTest, DISABLED_TupleConstant) { ComputationBuilder builder(client_, TestName()); - builder.ConstantLiteral(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()})); + builder.ConstantLiteral( + *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), + Literal::CreateR1({2.0, 42}).get()})); std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); @@ -179,7 +178,6 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 6d3797972507f2c17b545c612c0dd839212e5ae5..2d181938ded0804776847772d4bb58bbc5e334f4 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -70,6 +69,24 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } +TEST_F(ConvertTest, ConvertR1PREDToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({true, false, true}); + builder.ConvertElementType(a, S32); + + std::vector expected = {1, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1PREDToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({true, false, true}); + builder.ConvertElementType(a, F32); + + std::vector expected = {1., 0., 1.}; + ComputeAndCompareR1(&builder, expected, {}); +} + XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); @@ -197,7 +214,6 @@ TEST_F(ConvertTest, ConvertReshape) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 0b09416a74771a8a9df804dcae783dc220420fc2..fb50d9b0ebf5b4a6c9d244f699620e2dcb74acaf 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -21,7 +21,6 @@ limitations under the 
License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -63,8 +62,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = MakeUnique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -102,7 +100,6 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index ec19469fa66c16cff3d1349b7ccc1d0de94d0b54..a110082f9a52ded5e836fa835e82f790e05df0e0 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -115,10 +114,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -158,10 +157,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -201,10 +200,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) 
.ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -246,10 +245,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -273,10 +272,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -313,21 +312,18 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); std::iota(input_elems.begin(), input_elems.end(), 1.0f); - auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r5 = - LiteralUtil::Reshape(*input_r1, input_dims).ConsumeValueOrDie(); + auto input_r1 = Literal::CreateR1(input_elems); + auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); - auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r5 = - LiteralUtil::Reshape(*filter_r1, filter_dims).ConsumeValueOrDie(); + auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - 
auto expected_r1 = LiteralUtil::CreateR1( + auto expected_r1 = Literal::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); - auto expected_r5 = - LiteralUtil::Reshape(*expected_r1, {1, 3, 1, 2, 3}).ConsumeValueOrDie(); + auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); auto filter_literal = @@ -344,7 +340,6 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index b5afc2498dace11c57a7099e9a3d32eb2a387984..c8e74aa01a50042b1e5297920cc184b1eeb51fd3 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -28,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -1312,20 +1311,19 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { ComputationBuilder builder(client_, TestName()); - auto gradients_flat = LiteralUtil::CreateR1({1}); + auto gradients_flat = Literal::CreateR1({1}); auto gradients_literal = - LiteralUtil::Reshape(*gradients_flat, {1, 1, 1, 1, 1}) - .ConsumeValueOrDie(); + gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto gradients = builder.ConstantLiteral(*gradients_literal); - auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); + auto weights_flat = Literal::CreateR1({1, 10, 100}); auto weights_literal = - LiteralUtil::Reshape(*weights_flat, {1, 1, 1, 1, 3}).ConsumeValueOrDie(); + weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto weights = builder.ConstantLiteral(*weights_literal); - auto expected_flat = LiteralUtil::CreateR1({10}); + auto expected_flat = Literal::CreateR1({10}); auto expected_literal = - LiteralUtil::Reshape(*expected_flat, {1, 1, 1, 1, 1}).ConsumeValueOrDie(); + expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto mirrored_weights = builder.Rev(weights, {2, 3, 4}); builder.ConvWithGeneralPadding(gradients, mirrored_weights, @@ -1337,21 +1335,19 @@ TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { ComputationBuilder builder(client_, TestName()); - auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); + auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); auto 
activations_literal = - LiteralUtil::Reshape(*activations_flat, {1, 1, 1, 1, 4}) - .ConsumeValueOrDie(); + activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); auto activations = builder.ConstantLiteral(*activations_literal); - auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); + auto gradients_flat = Literal::CreateR1({100, 10, 1}); auto gradients_literal = - LiteralUtil::Reshape(*gradients_flat, {1, 1, 1, 1, 3}) - .ConsumeValueOrDie(); + gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto gradients = builder.ConstantLiteral(*gradients_literal); - auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); + auto expected_flat = Literal::CreateR1({13, 24, 130}); auto expected_literal = - LiteralUtil::Reshape(*expected_flat, {1, 1, 1, 1, 3}).ConsumeValueOrDie(); + expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto forward_conv = builder.ConvGeneralDilated( activations, gradients, @@ -1370,7 +1366,6 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 4c2413d0fe43d486ebf306fc51601467d6ebf7fd..76ae280f1a0f309d9aa159079827a7e2c7e833d7 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -58,39 +57,34 @@ class CopyOpTest : public HloTestBase { tensorflow::gtl::ArraySlice permutation); }; -TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*LiteralUtil::CreateR0(true)); -} +TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0(true)); } -TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*LiteralUtil::CreateR1({})); -} +TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1({})); } TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); + TestCopyOp(*Literal::CreateR1({1, 2, 3})); } TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4( + TestCopyOp(*Literal::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(*Literal::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } TEST_F(CopyOpTest, CopyParameterScalar) { auto builder = HloComputation::Builder(TestName()); // Copy literal to device to use as parameter. 
- auto literal = LiteralUtil::CreateR0(42.0); + auto literal = Literal::CreateR0(42.0); Shape shape = literal->shape(); auto constant_device_base = TransferToDevice(*literal); @@ -112,7 +106,7 @@ TEST_F(CopyOpTest, CopyParameterScalar) { TEST_F(CopyOpTest, CopyConstantR2Twice) { auto builder = HloComputation::Builder(TestName()); - auto literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto literal = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -134,7 +128,7 @@ TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal->mutable_shape()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); @@ -170,7 +164,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + std::unique_ptr literal = Literal::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -204,7 +198,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); + std::unique_ptr literal = Literal::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -247,7 +241,7 @@ using CopyOpClientTest = ClientLibraryTestBase; XLA_TEST_F(CopyOpClientTest, Copy0x0) { Shape in_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {0, 1}); Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); - auto empty = LiteralUtil::CreateFromShape(in_shape); + auto empty = 
Literal::CreateFromShape(in_shape); ComputationBuilder builder(client_, TestName()); auto param0 = builder.Parameter(0, in_shape, "input"); @@ -263,7 +257,6 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 32232acf6e34517587b80d5091dbb9d603223184..73772fdec02fc95cb6c8e0685037515183478e85 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -68,7 +67,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); @@ -89,7 +88,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { array(1, 1) = 4.0f; auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); + HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(array))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); @@ -105,7 +104,7 @@ XLA_TEST_F(CustomCallTest, auto b = 
HloComputation::Builder(TestName()); auto input = b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( + HloInstruction::CreateConstant(Literal::CreateR2FromArray2D( Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues")); @@ -129,7 +128,6 @@ XLA_TEST_F(CustomCallTest, int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 074753bf6f8f9e64626b9ed2015b94b58dfebc87..0c7c3a8ff6656b05041e672cca97b285a4420446 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -42,7 +41,8 @@ class DeallocationTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice arguments) { Computation computation = builder->Build().ConsumeValueOrDie(); auto global_data = - client_->Execute(computation, arguments).ConsumeValueOrDie(); + client_->Execute(computation, arguments, &execution_options_) + .ConsumeValueOrDie(); TF_CHECK_OK(client_->Transfer(*global_data).status()); return global_data; } @@ -143,7 +143,6 @@ XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index fcddffc1e1340028f11b67cbe14537a240120de7..3d6a995a245c636fd91aee8a71aeac53ada90b1c 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -48,7 +47,8 @@ class DeconstructTupleTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice arguments) { Computation computation = builder->Build().ConsumeValueOrDie(); auto global_data = - client_->Execute(computation, arguments).ConsumeValueOrDie(); + client_->Execute(computation, arguments, &execution_options_) + .ConsumeValueOrDie(); TF_CHECK_OK(client_->Transfer(*global_data).status()); return global_data; } @@ -173,7 +173,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); @@ -205,7 +205,6 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..60953a7421d410722b499625b4ce4b9ca90aa874 --- /dev/null +++ 
b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" + +namespace xla { +namespace { +TEST_F(ClientLibraryTestBase, DeepGraph) { + // TODO(b/62624812): To trigger the stack overflow this test is + // intended to track, we need to set kDepth to 20000. + // Unfortunately, setting it that high causes the test to time out. 
+ const int kDepth = 200; + ComputationBuilder b(client_, TestName()); + ComputationDataHandle x; + ComputationDataHandle y; + auto x_data = CreateR0Parameter(3, 0, "x", &b, &x); + auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); + ComputationDataHandle z = x; + for (int i = 0; i < kDepth; ++i) { + z = b.Add(z, y); + } + ComputeAndCompareR0(&b, /*expected=*/kDepth + 3, + {x_data.get(), y_data.get()}); +} +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 754eec1b1edc286b98d02f70c8e5661523bd85de..63a630f9e58b1ba2afd31a39253f737a950a9287 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -20,8 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -186,14 +184,14 @@ void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major, bool rhs_row_major) { std::unique_ptr> lhs_data = MakeLinspaceArray2D(0.0, 1.0, M, K); - std::unique_ptr lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + std::unique_ptr lhs_lit = Literal::CreateR2FromArray2DWithLayout( *lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))); auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie(); std::unique_ptr> rhs_data = MakeLinspaceArray2D(0.0, 1.0, K, N); - std::unique_ptr rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + std::unique_ptr rhs_lit = Literal::CreateR2FromArray2DWithLayout( *rhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))); auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie(); @@ -380,12 +378,12 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) { builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = client_ - ->TransferToServer(*LiteralUtil::CreateR4( + ->TransferToServer(*Literal::CreateR4( {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}})) .ConsumeValueOrDie(); auto y_data = client_ - ->TransferToServer(*LiteralUtil::CreateR4( + ->TransferToServer(*Literal::CreateR4( {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) .ConsumeValueOrDie(); @@ -416,14 +414,14 @@ TEST_F(DotOperationTest, TransposeFolding) { auto lhs_handle = client_ 
->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + *Literal::CreateR2FromArray2DWithLayout( *lhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + *Literal::CreateR2FromArray2DWithLayout( *rhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); @@ -462,8 +460,6 @@ int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendLayoutUtilFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index b7bb1792f3b9b96fea5f446c787eb55e2577b01b..f653766f39d3b38d06b08fd98b1dd237e990e12e 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -389,8 +388,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { std::unique_ptr literal = - LiteralUtil::CreateR3FromArray3D(values); - LOG(INFO) << name << ":" << LiteralUtil::ToString(*literal); + Literal::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << literal->ToString(); } }; @@ -470,7 +469,7 @@ void BM_DynamicSlice(int num_iters) { ComputationBuilder builder(client, "DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] - auto input_literal = LiteralUtil::CreateR4( + auto input_literal = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); auto input = builder.ConstantLiteral(*input_literal); @@ -488,7 +487,7 @@ void BM_DynamicSlice(int num_iters) { &allocator, 0) .ConsumeValueOrDie(); - auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); + auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( executors[device_ordinal], *start_indices_literal, buffer->mutable_buffer({}))); @@ -521,7 +520,6 @@ BENCHMARK(BM_DynamicSlice); int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git 
a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 80267e5459d2ab12e3530110c0def699b7695351..90c5aa65592302e076821aaaeaa701ae40c07a6c 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -114,7 +113,6 @@ TEST_F(FloorCeilTest, R0Ceil) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index ee4e92505d9dd1f880473f1e76e5be3f01a1cfb3..9c86c65e5bb5b90072f79f5dee1923fa92b36e21 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -47,7 +46,6 @@ TEST_F(FmaxSimpleTest, FmaxTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index fa36381267e80e3afe693a4d85152d2367956be3..7803d234fdfe3330e0a53f7739c2c9117854c67a 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -20,7 +20,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -29,7 +31,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -37,10 +41,13 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" using tensorflow::gtl::ArraySlice; +namespace se = ::perftools::gputools; + namespace xla { namespace { @@ -81,7 +88,7 @@ class FusionTest : public HloTestBase { HloInstruction* hlos[4]; for (int i = 0; i < Arity; ++i) { hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(operand_data[i]))); + Literal::CreateR2FromArray2D(operand_data[i]))); } auto answer_shape = ShapeUtil::MakeShape(prim_type, {test_width, test_height}); @@ -107,7 +114,7 @@ class FusionTest : public HloTestBase { ArraySlice(hlos, 0, Arity + 1), HloInstruction::FusionKind::kLoop); - auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); + auto expected = Literal::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); @@ -178,28 +185,27 @@ XLA_TEST_F(FusionTest, Test) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); + 
Literal::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); + Literal::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.62, 2.72, 3.14}}))); + Literal::CreateR2({{1.62, 2.72, 3.14}}))); auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); + Literal::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); - auto const10 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{true, false, true}, {false, true, false}}))); + Literal::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); + auto const10 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{true, false, true}, {false, true, false}}))); auto select11 = builder.AddInstruction( HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kSelect, const10, add8, const9)); @@ -214,7 +220,7 @@ XLA_TEST_F(FusionTest, Test) { const4, reshape3, add2, const1, const0}, HloInstruction::FusionKind::kLoop); - 
LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2({{0.5}, {2.72}}), + LiteralTestUtil::ExpectNear(*Literal::CreateR2({{0.5}, {2.72}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -226,11 +232,11 @@ XLA_TEST_F(FusionTest, Parameter) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); + Literal::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-2.0, -2.0, -2.0}}))); + Literal::CreateR2({{-2.0, -2.0, -2.0}}))); // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); @@ -240,7 +246,7 @@ XLA_TEST_F(FusionTest, Parameter) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + LiteralTestUtil::ExpectNear(*Literal::CreateR2({{-1.0, 0.0, 1.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -249,9 +255,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); + Literal::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); auto broadcast = builder.AddInstruction( 
HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); // add2 = broadcast(const_vector) + const_array @@ -265,7 +271,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -273,13 +279,13 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto single_element_array = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); + HloInstruction::CreateConstant(Literal::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {}), single_element_array)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(5), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(5), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -287,14 +293,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ 
-302,14 +308,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); + Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -317,13 +323,13 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); + HloInstruction::CreateConstant(Literal::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(7), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -331,13 +337,13 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); + HloInstruction::CreateConstant(Literal::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); hlo_module->AddEntryComputation(builder.Build()) 
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR3({{{7}}}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR3({{{7}}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -345,13 +351,13 @@ XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); + HloInstruction::CreateConstant(Literal::CreateR0(7))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(7), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -359,14 +365,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -374,14 +380,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = 
builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -389,14 +395,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -404,14 +410,14 @@ XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, 
HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1({3, 2, 1}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({3, 2, 1}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -430,10 +436,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -441,7 +447,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(15), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(15), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -449,10 +455,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, 
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -462,7 +468,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1({-15}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-15}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -470,9 +476,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); + Literal::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); Window window; ASSERT_TRUE( tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" @@ -512,7 +518,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + *Literal::CreateR2({{462, 2145}, {24871, 62491}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -568,12 +574,66 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } +void BM_ParallelFusion(int num_iters) { + // Simple element-wise computation to benchmark parallel task partitioning. 
+ tensorflow::testing::StopTiming(); + + se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); + auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); + StreamExecutorMemoryAllocator allocator(platform, executors); + + const int64 intra_op_parallelism_threads = 16; + xla::LocalClientOptions client_options; + client_options.set_platform(platform); + client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); + auto client = + ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); + + const int64 dim_size = 1024; + // Create a simple fusable elementwise computation. + ComputationBuilder builder(client, "ParallelFusion"); + Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size}); + auto input0 = builder.Broadcast(builder.ConstantR0(1.5f), + AsInt64Slice(input_shape.dimensions())); + auto input1 = builder.Broadcast(builder.ConstantR0(2.0f), + AsInt64Slice(input_shape.dimensions())); + auto input2 = builder.Broadcast(builder.ConstantR0(3.0f), + AsInt64Slice(input_shape.dimensions())); + auto x = builder.Mul(input0, input1); + auto y = builder.Add(x, input2); + auto computation = builder.Build().ConsumeValueOrDie(); + + std::unique_ptr executable = + client->Compile(computation, {}, ExecutableBuildOptions()) + .ConsumeValueOrDie(); + + // Run some warm-up executions. + ExecutableRunOptions options; + options.set_allocator(&allocator); + const int kWarmups = 2; + for (int i = 0; i < kWarmups; ++i) { + auto result = executable->Run({}, options); + ASSERT_TRUE(result.ok()); + } + + // Run benchmark. 
+ tensorflow::testing::BytesProcessed(static_cast(num_iters) * dim_size * + dim_size * sizeof(float)); + tensorflow::testing::UseRealTime(); + tensorflow::testing::StartTiming(); + for (int i = 0; i < num_iters; ++i) { + auto result = executable->Run({}, options); + ASSERT_TRUE(result.ok()); + } +} + +BENCHMARK(BM_ParallelFusion); + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); @@ -586,5 +646,6 @@ int main(int argc, char** argv) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; } + tensorflow::testing::RunBenchmarks(); return RUN_ALL_TESTS(); } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 5f7b7aa434e29980a7d813dfb57f3b7988ed6e6d..8149e2b7cc72018ef8deb61305bb61ceb77200f9 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -24,14 +24,12 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" @@ -56,17 +54,6 @@ struct HloTestBase::EigenThreadPoolWrapper { HloTestBase::HloTestBase() : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) { - // TODO(b/62411181): get rid of this flag entirely when the usual debug flags - // are piped to all HLO tests. - test_hlo_dumper_ = [](const HloModule& module, const string& label) { - legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags(); - if (flags->xla_hlo_test_generate_hlo_graph) { - const bool show_addresses = true; - const bool show_layouts = true; - hlo_graph_dumper::DumpGraph(*module.entry_computation(), label, - show_addresses, show_layouts); - } - }; VLOG(1) << "executing on platform " << backend_->platform()->Name(); } @@ -77,9 +64,16 @@ HloTestBase::~HloTestBase() { } } +/* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + + auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
+ debug_options.add_xla_disable_hlo_passes("constant_folding"); + + config.set_debug_options(debug_options); + return MakeUnique(TestName(), VersionedComputationHandle(), config); } @@ -91,7 +85,7 @@ StatusOr HloTestBase::Execute( Shape* result_shape) { TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend_->compiler()->Compile(std::move(module), test_hlo_dumper_, + backend_->compiler()->Compile(std::move(module), backend_->default_stream_executor())); se::Stream stream(backend_->default_stream_executor()); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 98bc35ae528970e262740631b283b7dbb6d01538..7f3d163290aba3cfcea1b3204e6c88134e172ed7 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -48,7 +48,7 @@ class HloTestBase : public ::testing::Test { // TestName() for its name; it will also automatically populate its debug // options from command-line flags. It's recommended to use this method to // create all HloModules for tests. - std::unique_ptr CreateNewModule(); + static std::unique_ptr CreateNewModule(); // Executes the given module and returns a global data handle. StatusOr Execute( @@ -104,8 +104,6 @@ class HloTestBase : public ::testing::Test { std::unique_ptr backend_; - Compiler::HloDumper test_hlo_dumper_; - // This vector contains handles of all the device memory allocations performed // by the test. These are deallocated on destruction of the test object. 
std::vector allocations_; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index eb979ad189db7b238ae6cc393d84d0c6c9fc27d1..cca04e17ca08c061bdffef7ef2d006d7efa40b7a 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -41,20 +41,25 @@ namespace xla { /* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, const Shape& actual) { - ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); - ASSERT_EQ(expected.element_type(), actual.element_type()) - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); - for (int i = 0; i < expected.dimensions_size(); ++i) { - ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } - ASSERT_EQ(expected.tuple_shapes_size(), actual.tuple_shapes_size()); - for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + ASSERT_EQ(ShapeUtil::IsTuple(expected), ShapeUtil::IsTuple(actual)); + if (ShapeUtil::IsTuple(expected)) { + ASSERT_EQ(ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + } + } else { + ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); + ASSERT_EQ(expected.element_type(), actual.element_type()) + << PrimitiveType_Name(expected.element_type()) << " vs " + << PrimitiveType_Name(actual.element_type()); + ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); + for (int i = 0; i < expected.dimensions_size(); ++i) { + ASSERT_EQ(expected.dimensions(i), 
actual.dimensions(i)) + << "mismatch in dimension #" << i + << " expected: " << ShapeUtil::HumanString(expected) + << " actual: " << ShapeUtil::HumanString(actual); + } } } @@ -128,8 +133,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, tensorflow::gtl::MutableArraySlice multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { - NativeT expected_value = LiteralUtil::Get(expected, multi_index); - NativeT actual_value = LiteralUtil::Get(actual, multi_index); + NativeT expected_value = expected.Get(multi_index); + NativeT actual_value = actual.Get(multi_index); ::testing::AssertionResult result = CompareEqual(expected_value, actual_value); return result; // Defines implicit coersion to bool. @@ -148,10 +153,10 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, /* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, const Literal& actual) { - EXPECT_TRUE(Equal(expected, actual)) << "expected:\n" - << LiteralUtil::ToString(expected) - << "\n\tvs actual:\n" - << LiteralUtil::ToString(actual); + EXPECT_TRUE(Equal(expected, actual)) + << "expected:\n" + << expected.ToString() << "\n\tvs actual:\n" + << actual.ToString(); } /* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, @@ -161,8 +166,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::Equal( const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); AssertEqualShapes(expected.shape(), actual.shape()); std::vector multi_index(expected.shape().dimensions_size(), 0); @@ -210,8 +215,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, ::testing::AssertionResult result = 
::testing::AssertionSuccess(); if (!match) { result = ::testing::AssertionFailure() - << "expected: " << LiteralUtil::ToString(expected) - << "\nactual: " << LiteralUtil::ToString(actual); + << "expected: " << expected.ToString() + << "\nactual: " << actual.ToString(); VLOG(1) << result.message(); } return result; @@ -219,8 +224,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, /* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); @@ -247,8 +252,8 @@ class NearComparator { // within the error bound. Emits useful log messages and dumps literals to // temporary files on failure. Returns true if literals match. 
bool ExpectNear(const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); @@ -282,9 +287,9 @@ class NearComparator { if (num_miscompares_ > 0) { if (!VLOG_IS_ON(1)) { LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) - << " " << LiteralUtil::ToString(expected); + << " " << expected.ToString(); LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) - << " " << LiteralUtil::ToString(actual); + << " " << actual.ToString(); } EXPECT_TRUE(num_miscompares_ == 0) << "\nmax relative mismatch at index " @@ -369,10 +374,9 @@ class NearComparator { void ExpectLiteralsNear(const Literal& expected, const Literal& actual, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { - bool near = - ExpectValuesNear(LiteralUtil::Get(expected, multi_index_), - LiteralUtil::Get(actual, multi_index_)); - LiteralUtil::Set(&miscompares_, multi_index_, !near); + bool near = ExpectValuesNear(expected.Get(multi_index_), + actual.Get(multi_index_)); + miscompares_.Set(multi_index_, !near); } else { for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index_[dimension] = i; @@ -437,8 +441,8 @@ class NearComparator { /* static */ ::testing::AssertionResult LiteralTestUtil::NearTuple( const Literal& expected, const Literal& actual, const ErrorSpec& error) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); if (!ShapeUtil::IsTuple(expected.shape()) || !ShapeUtil::IsTuple(actual.shape())) { @@ -504,8 +508,7 @@ class NearComparator { *shape_with_layout.mutable_layout() = 
LayoutUtil::MakeLayout(minor_to_major); // Allocate space in the new literal. - LiteralUtil::Reserve(ShapeUtil::ElementsIn(literal.shape()), - new_literal.get()); + new_literal->Reserve(ShapeUtil::ElementsIn(literal.shape())); // Copy data into new literal, element-by-element. for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { @@ -515,44 +518,36 @@ class NearComparator { IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); switch (literal.shape().element_type()) { case PRED: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U8: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case S32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case S64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case F32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); break; case F64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal->Set(to_multi_index, + 
literal.Get(from_multi_index)); break; default: LOG(FATAL) << "Unhandled primitive element type: " diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index a8b07a2c5d13e93d068cd475cb96a727c8346cc5..0def25f34e535f0680b231f1d3862fc338e8840f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -210,20 +210,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR0(expected), actual); + ExpectEqual(*Literal::CreateR0(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR1Equal( tensorflow::gtl::ArraySlice expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR1(expected), actual); + ExpectEqual(*Literal::CreateR1(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR2(expected), actual); + ExpectEqual(*Literal::CreateR2(expected), actual); } template @@ -231,46 +231,46 @@ template std::initializer_list>> expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR3(expected), actual); + ExpectEqual(*Literal::CreateR3(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR2FromArray2D(expected), actual); + ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR3FromArray3D(expected), actual); + ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const Literal& actual) { - 
ExpectEqual(*LiteralUtil::CreateR4FromArray4D(expected), actual); + ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR0(expected), actual, error); + ExpectNear(*Literal::CreateR0(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR1Near( tensorflow::gtl::ArraySlice expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR1(expected), actual, error); + ExpectNear(*Literal::CreateR1(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR2(expected), actual, error); + ExpectNear(*Literal::CreateR2(expected), actual, error); } template @@ -278,28 +278,28 @@ template std::initializer_list>> expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR3(expected), actual, error); + ExpectNear(*Literal::CreateR3(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR2FromArray2D(expected), actual, error); + ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR3FromArray3D(expected), actual, error); + ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error); + ExpectNear(*Literal::CreateR4FromArray4D(expected), 
actual, error); } template @@ -309,9 +309,9 @@ LiteralTestUtil::CreateRandomLiteral( const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = LiteralUtil::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - literal.get(), [&](tensorflow::gtl::ArraySlice indexes) { + std::unique_ptr literal = Literal::CreateFromShape(shape); + TF_RETURN_IF_ERROR(literal.get()->Populate( + [&](tensorflow::gtl::ArraySlice indexes) { return generator(indexes); })); return std::move(literal); diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index a94f45f73b7d058d6b82f91967f61624a28fea3d..2acf27ed390b0732ba40fcf505c746bd7d8b651e 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,9 +31,8 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + std::unique_ptr literal = Literal::MakeTuple({ + Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); LiteralTestUtil::ExpectEqual(*literal, *literal); } @@ -43,13 +42,11 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. 
auto unequal_things_are_equal = [] { - std::unique_ptr lhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + std::unique_ptr lhs = Literal::MakeTuple({ + Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); - std::unique_ptr rhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(64).get(), - LiteralUtil::CreateR0(42).get(), + std::unique_ptr rhs = Literal::MakeTuple({ + Literal::CreateR0(64).get(), Literal::CreateR0(42).get(), }); CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; }; @@ -58,8 +55,8 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto dummy_lambda = [] { - auto two = LiteralUtil::CreateR0(2); - auto four = LiteralUtil::CreateR0(4); + auto two = Literal::CreateR0(2); + auto four = Literal::CreateR0(4); ErrorSpec error(0.001); CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; }; @@ -88,11 +85,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { &literal_proto)); Literal literal(literal_proto); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", LiteralUtil::ToString(literal)); + EXPECT_EQ("2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", LiteralUtil::ToString(literal)); + EXPECT_EQ("4", literal.ToString()); } else if (result.find("miscompares") != string::npos) { - EXPECT_EQ("true", LiteralUtil::ToString(literal)); + EXPECT_EQ("true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index 796f43ea4edc2c4858eb85c7fa8a16bbe8401a4b..4cb383a78dfed8a4867f4b589c6c32db345dfc9f 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -62,7 +61,6 @@ TEST_F(LogTest, LogTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index e4dbd6864a325546fabd88b56acf341b99cb73c8..47a8acbf4ab76758d8387e84eb271c130aba5a64 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -170,7 +169,7 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. 
ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + std::unique_ptr param0_literal = Literal::CreateR0(42.0); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -184,7 +183,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -199,7 +198,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -213,7 +212,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -226,7 +225,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -240,7 +239,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an 
input R1F32 vector. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); + Literal::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -256,7 +255,7 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -273,7 +272,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // maps (lambda (x) (* x 2)) on the result. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -288,7 +287,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + std::unique_ptr param0_literal = Literal::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -385,11 +384,11 @@ TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. 
ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -434,12 +433,12 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); + Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); + Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -456,15 +455,15 @@ TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. 
ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); std::unique_ptr param2_literal = - LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); + Literal::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); @@ -517,11 +516,11 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -531,9 +530,10 @@ TEST_F(MapTest, MapOperantionWithBuildError) { StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_THAT(computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: binary op with " - "different element types: f32[] and u16[]")); + EXPECT_THAT( + computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: binary op BINOP_ADD with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. 
MapTestWithFullOpt runs all @@ -554,8 +554,8 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { sub_builder->Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + std::unique_ptr param0_literal = Literal::CreateR0(2.0f); + std::unique_ptr param1_literal = Literal::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -581,8 +581,8 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { sub_builder->Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + std::unique_ptr param0_literal = Literal::CreateR0(2.0f); + std::unique_ptr param1_literal = Literal::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -606,7 +606,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) { sub_builder->Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); + std::unique_ptr param0_literal = Literal::CreateR0(10.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -622,7 +622,6 @@ TEST_F(MapTestWithFullOpt, MapSquare) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 
51261f0ac1c15ee96dd0f749fec35971d73b34f2..717e9cd49489fe2533465c9e9d69e3b299dbab47 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -88,8 +87,8 @@ TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) { builder.Exp(data); std::unique_ptr expected = - LiteralUtil::CreateR2({{2.71828, 1.00000}, // row 0 - {0.36788, 1.64872}}); // row 1 + Literal::CreateR2({{2.71828, 1.00000}, // row 0 + {0.36788, 1.64872}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -116,8 +115,8 @@ TEST_F(MatOpsSimpleTest, MapTwoByTwo) { auto map = builder.Map({data}, add_half); std::unique_ptr expected = - LiteralUtil::CreateR2({{1.5, 0.5}, // row 0 - {-0.5, 1.0}}); // row 1 + Literal::CreateR2({{1.5, 0.5}, // row 0 + {-0.5, 1.0}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -134,8 +133,8 @@ TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) { auto max = builder.Max(lhs, rhs); std::unique_ptr expected = - LiteralUtil::CreateR2({{7.0, 6.0}, // row 0 - {3.0, -4.0}}); // row 1 + Literal::CreateR2({{7.0, 6.0}, // row 0 + {3.0, -4.0}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } @@ -181,14 +180,12 @@ TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { TF_ASSIGN_OR_ASSERT_OK( auto lhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + 
client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSIGN_OR_ASSERT_OK( auto rhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); ComputationBuilder builder(client_, TestName()); auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); @@ -218,7 +215,6 @@ INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index 4929e25c580c427a3f034ccf82e7821222be0d8a..56c15e5ff7256cc75a10733e5934894cc88a34da 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -60,7 +59,6 @@ XLA_TEST_F(SliceTest, Slice3D) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b34e1d7db24fbbc5927102bce94f576f3e6d4947 --- /dev/null +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::gtl::ArraySlice; + +namespace xla { +namespace { + +class MultiOutputFusionTest : public HloTestBase { + public: + ErrorSpec error_spec_{0.0001, 1e-2}; + + protected: + MultiOutputFusionTest() {} + void RunTest2D(bool manual_fusion, int64 size) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = CreateNewModule(); + + const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {}); + const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); + + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(8.0f))); 
+ auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, elem_shape0, "0")); + + auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape0, HloOpcode::kAdd, param0, const0)); + + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(elem_shape2, add1, {0, 1})); + + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, elem_shape2, "1")); + + HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape2, HloOpcode::kAdd, broadcast, param1)); + HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape2, HloOpcode::kSubtract, param1, broadcast)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2)); + auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); + + if (manual_fusion) { + auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( + ArraySlice({sub, add2}, 0, 2))); + auto gte0 = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0)); + auto gte1 = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 1)); + TF_CHECK_OK(dot->ReplaceOperandWith(0, gte0)); + TF_CHECK_OK(dot->ReplaceOperandWith(1, gte1)); + + CHECK_NE( + computation->CreateFusionInstruction( + {tuple, sub, add2, broadcast}, HloInstruction::FusionKind::kLoop), + nullptr); + } + + Literal input; + input.PopulateWithValue(2.5f, {size, size}); + auto p1 = TransferToDevice(input); + auto p0 = TransferToDevice(*Literal::CreateR0(-9.0f)); + + Literal expect; + expect.PopulateWithValue(size * 1.5f * 3.5f, {size, size}); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + } + + void RunTest1D(bool manual_fusion, int size) { + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = 
CreateNewModule(); + + const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size}); + const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, elem_shape_F32, "0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, elem_shape_U8, "1")); + + HloInstruction* param0_U8 = builder.AddInstruction( + HloInstruction::CreateConvert(elem_shape_U8, param0)); + HloInstruction* param1_F32 = builder.AddInstruction( + HloInstruction::CreateConvert(elem_shape_F32, param1)); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape_F32, HloOpcode::kAdd, param0, param1_F32)); + HloInstruction* sub_U8 = + builder.AddInstruction(HloInstruction::CreateBinary( + elem_shape_U8, HloOpcode::kSubtract, param0_U8, param1)); + HloInstruction* sub = builder.AddInstruction( + HloInstruction::CreateConvert(elem_shape_F32, sub_U8)); + + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {size, 1}), add)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kDot, sub, reshape)); + auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); + + if (manual_fusion) { + auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( + ArraySlice({sub_U8, add}, 0, 2))); + + auto gte0 = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0)); + auto gte1 = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape_F32, tuple, 1)); + TF_CHECK_OK(sub->ReplaceOperandWith(0, gte0)); + TF_CHECK_OK(reshape->ReplaceOperandWith(0, gte1)); + + CHECK_NE(computation->CreateFusionInstruction( + {tuple, sub_U8, add, param0_U8, param1_F32}, + HloInstruction::FusionKind::kLoop), + nullptr); + } + + Literal input0, input1; + input0.PopulateWithValue(2.5f, {size}); + 
input1.PopulateWithValue(1, {size}); + auto p0 = TransferToDevice(input0); + auto p1 = TransferToDevice(input1); + + Literal expect = *Literal::CreateR0(size * 1.5f * 3.5f); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + } +}; + +XLA_TEST_F(MultiOutputFusionTest, 2DNofusion) { RunTest2D(false, 5); } +XLA_TEST_F(MultiOutputFusionTest, 2DFusion) { RunTest2D(true, 5); } +XLA_TEST_F(MultiOutputFusionTest, 2DFusionSize129) { RunTest2D(true, 129); } +XLA_TEST_F(MultiOutputFusionTest, DiffentTypesNoFusion) { RunTest1D(false, 8); } +XLA_TEST_F(MultiOutputFusionTest, DiffentTypesFusion) { RunTest1D(true, 8); } + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 4922bbf21c447e4db193e63919d4df5f8079e3be..e270a0477fe140b75b6d4ddffb5d4d98ced2171d 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -183,8 +182,8 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); - auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = LiteralUtil::Relayout(*input, layout); + auto input = Literal::CreateR4FromArray4D(input_array); + input = input->Relayout(layout); b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); @@ -228,8 +227,8 @@ XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 0, 0, 0) = 1.0f; input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; - auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = LiteralUtil::Relayout(*input, layout); + auto input = Literal::CreateR4FromArray4D(input_array); + input = input->Relayout(layout); b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); @@ -308,7 +307,7 @@ XLA_TEST_F(PadTest, Large2DPad) { auto ones = MakeUnique>(4, 4); ones->Fill(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*ones); + auto input_literal = Literal::CreateR2FromArray2D(*ones); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -334,7 +333,7 @@ XLA_TEST_F(PadTest, AllTypes2DPad) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(0.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ 
-365,7 +364,7 @@ XLA_TEST_F(PadTest, High2DPad) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -397,7 +396,7 @@ XLA_TEST_F(PadTest, NegativePadding2D) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -429,7 +428,7 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -453,7 +452,7 @@ XLA_TEST_F(PadTest, ReducePad) { auto ones = MakeUnique>(2, 2, 2, 2); ones->Fill(1.0); - auto input_literal = LiteralUtil::CreateR4FromArray4D(*ones); + auto input_literal = Literal::CreateR4FromArray4D(*ones); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -470,7 +469,6 @@ XLA_TEST_F(PadTest, ReducePad) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if 
(!parse_result) { diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 3e1bfcd3090df6df69e344c157390a41476f17a4..a7692fceb4751a4e81851c382be0371efbff8dc8 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -44,8 +43,7 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR0(3.14159f); + std::unique_ptr param0_literal = Literal::CreateR0(3.14159f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -57,7 +55,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -70,7 +68,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -83,7 +81,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, 
ConstantR1U8Param) { ComputationBuilder builder(client_, TestName()); string str("hello world"); - std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + std::unique_ptr param0_literal = Literal::CreateR1U8(str); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -96,7 +94,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); + Literal::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -108,7 +106,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + std::unique_ptr param0_literal = Literal::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -124,12 +122,12 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = builder.Parameter(1, literal1->shape(), "param1"); @@ -155,7 +153,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an 
incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + std::unique_ptr literal = Literal::CreateR0(3.14159f); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -173,12 +171,12 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = builder.Parameter(1, literal1->shape(), "param1"); @@ -193,12 +191,11 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. 
ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = - LiteralUtil::CreateR1({10, 20, 30}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20, 30}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); @@ -238,7 +235,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + std::unique_ptr literal = Literal::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); ComputationDataHandle param = @@ -268,9 +265,9 @@ XLA_TEST_F(ParamsTest, std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR1({4, 5, 6}).get(), + ->TransferToServer(*Literal::MakeTuple({ + Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR1({4, 5, 6}).get(), })) .ConsumeValueOrDie(); @@ -282,7 +279,7 @@ XLA_TEST_F(ParamsTest, // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 2}, {3, 4}, }); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -296,7 +293,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. 
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 3}, {2, 4}, }); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); @@ -309,7 +306,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 3}, {2, 4}, }); const Shape original = literal->shape(); @@ -322,7 +319,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { std::reverse(original_layout.begin(), original_layout.end()); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); - ASSERT_EQ(2, LiteralUtil::Get(*literal, {0, 1})); + ASSERT_EQ(2, literal->Get({0, 1})); } // Use the original shape in building the computation. ComputationBuilder builder(client_, TestName()); @@ -344,7 +341,6 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/tools/ci_build/builds/tensorboard.sh b/tensorflow/compiler/xla/tests/plugin.bzl old mode 100755 new mode 100644 similarity index 60% rename from tensorflow/tools/ci_build/builds/tensorboard.sh rename to tensorflow/compiler/xla/tests/plugin.bzl index 77bd29c09f8a1009708ed2bd95987df954fd4a77..1b10c778ce3587d9b3f345a92abbb4da92bcad9b --- a/tensorflow/tools/ci_build/builds/tensorboard.sh +++ b/tensorflow/compiler/xla/tests/plugin.bzl @@ -1,5 +1,4 @@ -#!/usr/bin/env bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Additional XLA devices to be included in the unit test suite.""" -set -e - -export LAUNCHPAD_CHROME=${LAUNCHPAD_CHROME:-$(which chromium-browser)} - -cd tensorflow/tensorboard - -# Install all js dependencies (tooling via npm, frontend assets via bower) -npm run prepare +# Example: +# +# plugins = { +# "foo": { +# "deps": [ +# "//tensorflow/compiler/plugin/foo:foo_lib", +# "//tensorflow/compiler/plugin/foo:test_macros", +# ], +# "copts": [], +# "tags": [], +# "args": [] +# }, +# } -npm run compile +plugins = {} -# Run wct in headless chrome using xvfb -xvfb-run ./node_modules/web-component-tester/bin/wct --skip-plugin=sauce diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index b031725d8abd897c83e40a3514bcccb7d7d76acf..d865297ae612f614f45aa6b4b226e15ee154ed2f 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -142,7 +141,6 @@ TEST_F(PredTest, AnyR2False) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 5117478bfd55093a82a5fa361feb5cf59fd68fd1..ed994fda4501deec78b75154a96ee88755a1c7c4 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -58,11 +57,10 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - LiteralUtil::EachCell(*actual, - [=](tensorflow::gtl::ArraySlice, T value) { - EXPECT_LE(a, value); - EXPECT_LT(value, b); - }); + actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { + EXPECT_LE(a, value); + EXPECT_LT(value, b); + }); } void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { @@ -71,7 +69,7 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { builder.RngBernoulli(builder.ConstantR0(p), shape); TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(42); TF_ASSIGN_OR_ASSERT_OK( auto actual, @@ -79,8 +77,8 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { &execution_options)); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); int32 sum = 0; - LiteralUtil::EachCell( - *actual, [&sum](tensorflow::gtl::ArraySlice, uint32 value) { + actual->EachCell( + [&sum](tensorflow::gtl::ArraySlice, uint32 value) { EXPECT_TRUE(value == 0 || value == 1); sum += value; }); @@ -124,10 +122,8 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) { SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); std::vector counts(range_size, 0); - LiteralUtil::EachCell( - *actual, 
[&counts](tensorflow::gtl::ArraySlice, int32 value) { - ++counts[value]; - }); + actual->EachCell([&counts](tensorflow::gtl::ArraySlice, + int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); @@ -170,7 +166,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr param0_data, client_->TransferToServer(*param0_literal)); @@ -180,7 +176,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(125); TF_ASSIGN_OR_ASSERT_OK( auto actual, @@ -209,10 +205,10 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { return builder.Build(); }; - ExecutionOptions execution_options1; + ExecutionOptions execution_options1 = execution_options_; execution_options1.set_seed(42); - ExecutionOptions execution_options2; + ExecutionOptions execution_options2 = execution_options_; execution_options2.set_seed(65); std::unique_ptr result1; @@ -247,9 +243,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options2)); TF_ASSIGN_OR_ASSERT_OK( - result5, client_->ExecuteAndTransfer(computation, /*arguments=*/{})); + result5, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options_)); TF_ASSIGN_OR_ASSERT_OK( - result6, client_->ExecuteAndTransfer(computation, /*arguments=*/{})); + result6, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options_)); } LiteralTestUtil::ExpectEqual(*result1, *result2); @@ -273,13 +271,23 @@ XLA_TEST_F(PrngTest, TenValuesN01) { // TODO(b/25995601): Test that resultant values are reasonable } 
+XLA_TEST_F(PrngTest, RngUniformCrash) { + ComputationBuilder builder(client_, TestName()); + + // This used to crash XLA during LLVM IR generation for CPUs. + auto rng_uniform = builder.RngUniform(builder.ConstantR0(0), + builder.ConstantR0(1000 * 1000), + ShapeUtil::MakeShape(S32, {})); + SetSeed(0); + ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index 4a02567a1a2ea8014cceca085c3d3d8589d6500f..0078733e197685fea575e78b8435485ea9de4926 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -46,7 +45,6 @@ TEST_F(QueryInferredShapeTest, OnePlusOneShape) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a66c9b44872d7e4e5db7252b5bd0291f3ef88b4e --- /dev/null +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -0,0 +1,261 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ReducePrecisionTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +// For reduction to IEEE-f16, we want to test the following cases, in both +// positive and negative variants. (Note: IEEE-f16 is 5 exponent bits and 10 +// mantissa bits.) +// +// Vectors of exponent and mantissa sizes to test. We want to test IEEE-f32 (a +// no-op), IEEE-f16, and exponent-reduction-only and mantissa-reduction-only +// variants of IEEE-f16. +static const int exponent_sizes[] = {8, 5, 5, 8}; +static const int mantissa_sizes[] = {23, 10, 23, 10}; + +string TestDataToString(const ::testing::TestParamInfo data) { + int i = data.param; + return tensorflow::strings::StrCat(exponent_sizes[i], "_exponent_bits_", + mantissa_sizes[i], "_mantissa_bits"); +} + +// The FPVAL macro allows us to write out the binary representation of the +// input and expected values in a more readable manner. 
The mantissa bits +// are separated into the "high" bits (retained with reduction to IEEE-f16) +// and the "low" bits (truncated with reduction to IEEE-f16). +#define FPVAL(EXPONENT, HIGH_MANTISSA, LOW_MANTISSA) \ + ((0b##EXPONENT << 23) + (0b##HIGH_MANTISSA << 13) + (0b##LOW_MANTISSA)) + +// Each element in the test-value array consists of four numbers. The first is +// the input value and the following are the expected output values for the +// various precision-reduction cases. +static const uint32_t test_values[][4] = { + // True zero. + { + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000) // 0.0 + }, + // Largest exponent that underflows to zero. + { + FPVAL(01110000, 0000000000, 0000000000000), // 3.05176e-05 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(01110000, 0000000000, 0000000000000) // 3.05176e-05 + }, + // Largest value that rounds to a denormal and thus clamps to zero. + { + FPVAL(01110000, 1111111111, 0111111111111), // 6.10203e-05 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(01110000, 1111111111, 0000000000000) // 6.10054e-05 + }, + // Smallest value that doesn't underflow to zero, due to mantissa rounding + // up and incrementing the exponent out of the denormal range. + { + FPVAL(01110000, 1111111111, 1000000000000), // 6.10203e-05 + FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 + FPVAL(00000000, 0000000000, 0000000000000), // 0.0 + FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05 + }, + // Smallest value that doesn't underflow to zero even without mantissa + // rounding. 
+ { + FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 + FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 + FPVAL(01110001, 0000000000, 0000000000000), // 6.10352e-05 + FPVAL(01110001, 0000000000, 0000000000000) // 6.10352e-05 + }, + // One (to make sure bias-handling is done correctly. + { + FPVAL(01111111, 0000000000, 0000000000000), // 1.0 + FPVAL(01111111, 0000000000, 0000000000000), // 1.0 + FPVAL(01111111, 0000000000, 0000000000000), // 1.0 + FPVAL(01111111, 0000000000, 0000000000000) // 1.0 + }, + // Values in a space where ties round down due to ties-to-even: + // Value with highest mantissa that rounds down. + { + FPVAL(01111111, 0000000000, 1000000000000), // 1.00049 + FPVAL(01111111, 0000000000, 0000000000000), // 1.0 + FPVAL(01111111, 0000000000, 1000000000000), // 1.00049 + FPVAL(01111111, 0000000000, 0000000000000) // 1.0 + }, + // Value with lowest mantissa that rounds up. + { + FPVAL(01111111, 0000000000, 1000000000001), // 1.00049 + FPVAL(01111111, 0000000001, 0000000000000), // 1.00098 + FPVAL(01111111, 0000000000, 1000000000001), // 1.00049 + FPVAL(01111111, 0000000001, 0000000000000) // 1.00098 + }, + // Values in a space where ties round up due to ties-to-even: + // Value with highest mantissa that rounds down. + { + FPVAL(01111111, 0000000001, 0111111111111), // 1.00146 + FPVAL(01111111, 0000000001, 0000000000000), // 1.00098 + FPVAL(01111111, 0000000001, 0111111111111), // 1.00146 + FPVAL(01111111, 0000000001, 0000000000000) // 1.00098 + }, + // Value with a mantissa that rounds up. + { + FPVAL(01111111, 0000000001, 1000000000000), // 1.00146 + FPVAL(01111111, 0000000010, 0000000000000), // 1.00195 + FPVAL(01111111, 0000000001, 1000000000000), // 1.00146 + FPVAL(01111111, 0000000010, 0000000000000) // 1.00195 + }, + // Largest value that does not overflow to infinity. 
+ { + FPVAL(10001110, 1111111111, 0111111111111), // 65520.0 + FPVAL(10001110, 1111111111, 0000000000000), // 65504.0 + FPVAL(10001110, 1111111111, 0111111111111), // 65520.0 + FPVAL(10001110, 1111111111, 0000000000000) // 65504.0 + }, + // Smallest value that overflows to infinity due to mantissa rounding up. + { + FPVAL(10001110, 1111111111, 1000000000000), // 65520.0 + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(10001110, 1111111111, 1000000000000), // 65520.0 + FPVAL(10001111, 0000000000, 0000000000000) // 65536.0 + }, + // Smallest value that overflows to infinity, without mantissa rounding. + { + FPVAL(10001111, 0000000000, 0000000000000), // 65536.0 + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(10001111, 0000000000, 0000000000000) // 65536.0 + }, + // Smallest value that overflows to infinity due to mantissa rounding up, + // even when exponent bits aren't reduced. + { + FPVAL(11111110, 1111111111, 1000000000000), // 3.40199e+38 + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000) // Inf + }, + // True infinity. + { + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000), // Inf + FPVAL(11111111, 0000000000, 0000000000000) // Inf + }, + // NAN with a 1 in the preserved bits. + { + FPVAL(11111111, 1000000000, 0000000000000), // NaN + FPVAL(11111111, 1000000000, 0000000000000), // NaN + FPVAL(11111111, 1000000000, 0000000000000), // NaN + FPVAL(11111111, 1000000000, 0000000000000) // NaN + }, + // NAN with a 1 in the truncated bits. 
+ { + FPVAL(11111111, 0000000000, 0000000000001), // NaN + FPVAL(11111111, 0000000000, 0000000000001), // NaN + FPVAL(11111111, 0000000000, 0000000000001), // NaN + FPVAL(11111111, 0000000000, 0000000000001) // NaN + }, + // NAN with all ones, causing rounding overflow. + { + FPVAL(11111111, 1111111111, 1111111111111), // NaN + FPVAL(11111111, 1111111111, 1111111111111), // NaN + FPVAL(11111111, 1111111111, 1111111111111), // NaN + FPVAL(11111111, 1111111111, 1111111111111) // NaN + }}; + +XLA_TEST_P(ReducePrecisionTest, ReducePrecisionF32) { + int index = GetParam(); + int exponent_bits = exponent_sizes[index]; + int mantissa_bits = mantissa_sizes[index]; + + std::vector input_values; + std::vector expected_values; + + const uint32_t sign_bit = 1u << 31; + for (const auto& test_value : test_values) { + // Add positive values. + input_values.push_back(tensorflow::bit_cast(test_value[0])); + expected_values.push_back(tensorflow::bit_cast(test_value[index])); + // Add negative values. We do this in the + input_values.push_back( + tensorflow::bit_cast(test_value[0] & sign_bit)); + expected_values.push_back( + tensorflow::bit_cast(test_value[index] & sign_bit)); + } + + // This is required for proper handling of NaN values. 
+ SetFastMathDisabled(true); + + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr a_literal = Literal::CreateR1({input_values}); + std::unique_ptr a_data = + client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + auto a = builder.Parameter(0, a_literal->shape(), "a"); + + auto reduce_precision = + builder.ReducePrecision(a, exponent_bits, mantissa_bits); + + ComputeAndCompareR1(&builder, expected_values, {a_data.get()}); +} + +INSTANTIATE_TEST_CASE_P(ReducePrecisionTest, ReducePrecisionTest, + ::testing::Values(0, 1, 2, 3), TestDataToString); + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index ff24177520eab5c6c2061d01223530249050448c..ac65a47afa573e7eb04a87d216a9e17878a11d4c 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -40,7 +40,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -64,12 +63,12 @@ class ReduceTest : public ClientLibraryTestBase { ReduceTest() { // Implementation note: laid out z >> y >> x by default. // clang-format off - literal_2d_ = LiteralUtil::CreateR2({ + literal_2d_ = Literal::CreateR2({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 }); - literal_3d_ = LiteralUtil::CreateR3Projected({ + literal_3d_ = Literal::CreateR3Projected({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 @@ -98,7 +97,7 @@ class ReduceTest : public ClientLibraryTestBase { } } std::unique_ptr input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -130,7 +129,7 @@ class ReduceTest : public ClientLibraryTestBase { builder.Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); + std::unique_ptr input_literal = Literal::CreateR1(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -157,9 +156,9 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout({minor, major})); + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr 
input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -185,9 +184,9 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout({minor, major})); + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -306,9 +305,8 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = - LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1})); + Literal::CreateR2FromArray2D(input_data); + input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -339,9 +337,8 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = - LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1})); + Literal::CreateR2FromArray2D(input_data); + input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -372,7 +369,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal::CreateR3FromArray3D(input_data); 
std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -435,7 +432,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { auto max = CreateScalarMaxComputation(F32, &builder); Array2D input(300, 250); input.FillRandom(214.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(input); + auto input_literal = Literal::CreateR2FromArray2D(input); builder.Reduce(builder.ConstantLiteral(*input_literal), builder.ConstantR0(FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; @@ -450,7 +447,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { auto min = CreateScalarMinComputation(F32, &builder); Array2D input(150, 130); input.FillRandom(214.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(input); + auto input_literal = Literal::CreateR2FromArray2D(input); builder.Reduce(builder.ConstantLiteral(*input_literal), builder.ConstantR0(FLT_MAX), min, {0, 1}); @@ -580,9 +577,9 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { Array3D input_array(bounds[0], bounds[1], bounds[2]); input_array.FillRandom(3.14f, 0.05); - auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout(GetParam().layout)); + auto input_literal = Literal::CreateR3FromArray3D(input_array); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -630,7 +627,6 @@ INSTANTIATE_TEST_CASE_P( int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 
ec7b47bc283538d7d9219610e4297fee8028d07f..4b2fa683d8551a173fb9ddb4b05ce3c344baba4c 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -58,7 +57,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_strides, Padding padding) { builder_.ReduceWindow( - input, builder_.ConstantLiteral(LiteralUtil::MinValue(F32)), + input, builder_.ConstantLiteral(Literal::MinValue(F32)), CreateScalarMax(), window_dimensions, window_strides, padding); } @@ -67,7 +66,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_strides, Padding padding) { builder_.ReduceWindow(input, - builder_.ConstantLiteral(LiteralUtil::MaxValue(F32)), + builder_.ConstantLiteral(Literal::MaxValue(F32)), CreateScalarMinComputation(F32, &builder_), window_dimensions, window_strides, padding); } @@ -75,6 +74,12 @@ class ReduceWindowTest : public ClientLibraryTestBase { ComputationBuilder builder_; }; +TEST_F(ReduceWindowTest, DISABLED_ON_CPU(DISABLED_ON_GPU(Min3In5Stride2))) { + const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); + ReduceWindowMin(input, {3}, {2}, Padding::kValid); + ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) { Array4D input_array(1, 0, 2, 1); @@ -132,6 +137,26 @@ TEST_F(ReduceWindowTest, Along2ndMinorDim) { ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); } +TEST_F(ReduceWindowTest, 
AmongMajor2Dims) { + Array4D input_array(4, 4, 6, 8); + input_array.FillWithMinorDimNum(); + + int win_len = 3; + int win_stride = 1; + + Padding padding = Padding::kSame; + const auto input_data_handle = + builder_.ConstantR4FromArray4D(input_array); + // Reduce only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); +} + TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { Array4D input_array(9, 12, 4, 89); input_array.FillRandom(2.0f); @@ -184,202 +209,6 @@ TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); } -// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. 
-TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmall) { - Array4D input_array(2, 2, 4, 16); - - Array2D yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, - 11.f, 12.f, 13.f, 14.f, 15.f}, - {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f}, - {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, - 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f}, - {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, - 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}}); - input_array.FillWithYX(yx); - - int win_len = 2; - int win_stride = 2; - const auto input = builder_.ConstantR4FromArray4D(input_array); - Padding padding = Padding::kValid; - ReduceWindowAdd(input, {1, 1, win_len, win_len}, - {1, 1, win_stride, win_stride}, padding); - - auto res = ReferenceUtil::ReduceWindow4DAdd( - input_array, 0.0f, {1, 1, win_len, win_len}, - {1, 1, win_stride, win_stride}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); -} - -// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. 
-TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmallOverlapped) { - constexpr int64 p = 2; - constexpr int64 z = 2; - constexpr int64 y = 4; - constexpr int64 x = 16; - Array4D input_array(p, z, y, x); - - Array2D yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, - 11.f, 12.f, 13.f, 14.f, 15.f}, - {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f}, - {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, - 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f}, - {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, - 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}}); - input_array.FillWithYX(yx); - - int win_len = 4; - int win_stride = 2; - const auto input = builder_.ConstantR4FromArray4D(input_array); - ReduceWindowAdd(input, {1, 1, win_len, win_len}, - {1, 1, win_stride, win_stride}, Padding::kValid); - - // Expected result - Array2D yx_result({{408.f, 440.f, 472.f, 504.f, 536.f, 568.f, 600.f}}); - Array4D expected(p, z, 1, 7); - expected.FillWithYX(yx_result); - ComputeAndCompareR4(&builder_, expected, {}, ErrorSpec(1e-3, 1e-3)); -} - -TEST_F(ReduceWindowTest, MaxTrivial) { - const auto input = builder_.ConstantR1({42}); - ReduceWindowMax(input, {1}, {1}, Padding::kValid); - ComputeAndCompareR1(&builder_, {42}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add3In3) { - const auto input = builder_.ConstantR1({20, 100, 3}); - ReduceWindowAdd(input, {3}, {1}, Padding::kValid); - ComputeAndCompareR1(&builder_, {123}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add4In16Stride4) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - ReduceWindowAdd(input, {4}, {4}, Padding::kValid); - ComputeAndCompareR1(&builder_, {10, 26, 42, 58}, {}, - ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, DISABLED_ON_CPU(DISABLED_ON_GPU(Min3In5Stride2))) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); - ReduceWindowMin(input, {3}, {2}, 
Padding::kValid); - ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max3In3) { - const auto input = builder_.ConstantR1({20, 100, 3}); - ReduceWindowMax(input, {3}, {1}, Padding::kValid); - ComputeAndCompareR1(&builder_, {100}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add2In3) { - const auto input = builder_.ConstantR1({100, 10, 1}); - ReduceWindowAdd(input, {2}, {1}, Padding::kValid); - ComputeAndCompareR1(&builder_, {110, 11}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); - ReduceWindowAdd(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {11100, 111}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max4In16Stride4) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - ReduceWindowMax(input, {4}, {4}, Padding::kValid); - ComputeAndCompareR1(&builder_, {4, 8, 12, 16}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max4In16Stride3) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - ReduceWindowMax(input, {4}, {3}, Padding::kValid); - ComputeAndCompareR1(&builder_, {4, 7, 10, 13, 16}, {}, - ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max4In16Stride8) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - ReduceWindowMax(input, {4}, {8}, Padding::kValid); - ComputeAndCompareR1(&builder_, {4, 12}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); - ReduceWindowMax(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {10000, 100}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Max3In5Stride1) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 101}); - ReduceWindowMax(input, {3}, {1}, 
Padding::kValid); - ComputeAndCompareR1(&builder_, {10000, 1000, 101}, {}, - ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add3In4Stride2) { - const auto input = builder_.ConstantR1({1000, 100, 10, 1}); - ReduceWindowAdd(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {1110}, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add2In3SamePad) { - const auto input = builder_.ConstantR1({100, 10, 1}); - ReduceWindowAdd(input, {2}, {1}, Padding::kSame); - ComputeAndCompareR1(&builder_, {110, 11, 1}, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add3In3SamePad) { - const auto input = builder_.ConstantR1({100, 10, 1}); - ReduceWindowAdd(input, {3}, {1}, Padding::kSame); - ComputeAndCompareR1(&builder_, {110, 111, 11}, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add3In3Stride3SamePad) { - const auto input = builder_.ConstantR1({100, 10, 1}); - ReduceWindowAdd(input, {3}, {2}, Padding::kSame); - ComputeAndCompareR1(&builder_, {110, 11}, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add2x2In2x2Overlapped) { - Array2D input_array({{1.2f, -2.5f, 0.9f, 1.0f}, - {3.7f, 0.2f, -1.0f, -0.2f}, - {-0.4f, 2.7f, 1.1f, 2.2f}, - {0.6f, 1.7f, 1.4f, -0.2f}}); - auto input = builder_.ConstantR2FromArray2D(input_array); - ReduceWindowAdd(input, {2, 2}, {1, 1}, Padding::kValid); - Array2D expected( - {{2.6f, -2.4f, 0.7f}, {6.2f, 3.0f, 2.1f}, {4.6f, 6.9f, 4.5f}}); - ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add2x2In2x2Disjoint) { - Array2D input_array({{1.2f, -2.5f, 0.9f, 1.0f}, - {3.7f, 0.2f, -1.0f, -0.2f}, - {-0.4f, 2.7f, 1.1f, 2.2f}, - {0.6f, 1.7f, 1.4f, -0.2f}}); - auto input = builder_.ConstantR2FromArray2D(input_array); - ReduceWindowAdd(input, {2, 2}, {2, 2}, Padding::kValid); - Array2D expected({ - {2.6f, 0.7f}, {4.6f, 4.5f}, - }); - ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(ReduceWindowTest, Add1x2In2x2Same) { - Array2D 
input_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); - auto input = builder_.ConstantR2FromArray2D(input_array); - ReduceWindowAdd(input, {1, 2}, {1, 1}, Padding::kSame); - Array2D expected({ - {3.0f, 2.0f}, {7.0f, 4.0f}, - }); - ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); -} - XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { Array3D input_array(2, 1, 2); input_array(0, 0, 0) = 1000; @@ -470,13 +299,620 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { ComputeAndCompareR4(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); } +TEST_F(ReduceWindowTest, R4UnitWindow) { + Array4D input_array(13, 12, 8, 15); + input_array.Fill(1.0f); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); + ComputationDataHandle input = + builder_.Parameter(0, input_literal->shape(), "operand"); + + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, + {1, 4, 1, 1}, padding); + + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + ComputeAndCompareR4(&builder_, *res, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { + Array4D input_array(2, 1, 27, 119); + input_array.FillRandom(2.0f); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + ComputationDataHandle input = + builder_.Parameter(0, input_literal->shape(), "operand"); + + int win_len = 1; + int stride = 8; + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + 
ComputeAndCompareR4(&builder_, *res, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { + Array4D input_array(3, 2, 4, 64); + input_array.FillRandom(2.0f); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + ComputationDataHandle input = + builder_.Parameter(0, input_literal->shape(), "operand"); + + int win_len = 3; + int stride = 1; + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + ComputeAndCompareR4(&builder_, *res, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { + Array4D input_array(1, 3, 12, 200); + input_array.FillRandom(2.0f); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + ComputationDataHandle input = + builder_.Parameter(0, input_literal->shape(), "operand"); + + int win_len = 8; + int stride = 5; + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + auto res = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); + + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + ComputeAndCompareR4(&builder_, *res, {input_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { + Array4D input_array(6, 4, 10, 130); + input_array.FillRandom(2.0f); + + int win_len = 3; + int win_stride = 2; + + Padding padding = Padding::kSame; + const auto input_data_handle = + builder_.ConstantR4FromArray4D(input_array); + // Reduce 
only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(ReduceWindowTest, Add24In1152_NoOverlap) { + std::vector input_vector(128 * 9, 1); + const auto input = builder_.ConstantR1(input_vector); + ReduceWindowAdd(input, {32}, {128}, Padding::kValid); + ComputeAndCompareR1(&builder_, {32, 32, 32, 32, 32, 32, 32, 32, 32}, + {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ReduceWindowTest, Add128In128Stride128) { + const auto input = builder_.ConstantR1( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + ReduceWindowAdd(input, {128}, {128}, Padding::kValid); + ComputeAndCompareR1(&builder_, {1088}, {}, ErrorSpec(0.0001)); +} + +// Regression test for a bug that appeared in Inception (b/34784899). 
+TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { + Array2D input_array(14, 14, 1.0f); + ComputationDataHandle input = + builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {14, 14}); + + int win_len = 3; + int stride = 1; + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding); + + auto res = ReferenceUtil::ReduceWindow2DAdd( + input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); + + ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { + Array2D input_array(6, 4, 1.0f); + ComputationDataHandle input = + builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {6, 4}); + + Padding padding = Padding::kSame; + ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); + + auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, + padding); + + ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); +} + +enum Reducer { kAdd, kMax }; + +struct R4ReduceWindowTestData { + int64 base_bounds[4]; + int64 window_bounds[4]; + int64 strides[4]; + int64 pad_low[4]; + int64 pad_high[4]; + + Reducer reducer; +}; + +string R4ReduceWindowTestDataToString( + const ::testing::TestParamInfo& data) { + string str = tensorflow::strings::StrCat( + "base_bounds_", + tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "__window_bounds_", + tensorflow::str_util::Join(data.param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(data.param.pad_low, "x"), // + "__pad_high_", tensorflow::str_util::Join(data.param.pad_high, "x"), // + (data.param.reducer == kAdd) ? "add" : "max"); + CHECK(data.param.reducer == kAdd || data.param.reducer == kMax); + + // Test names are not allowed to contain the '-' character. 
+ std::replace(str.begin(), str.end(), '-', 'n'); + return str; +} + +class R4ReduceWindowTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + protected: + void DoIt() { + ComputationBuilder b(client_, TestName()); + const auto& param = GetParam(); + + const float kInitValue = 0.0f; + + Array4D input(param.base_bounds[0], param.base_bounds[1], + param.base_bounds[2], param.base_bounds[3]); + input.FillIota(1); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4D(input); + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr input_arg, + client_->TransferToServer(*input_literal)); + + std::vector> padding(4); + for (int i = 0; i < 4; ++i) { + padding[i] = {param.pad_low[i], param.pad_high[i]}; + } + + auto parameter = b.Parameter(0, input_literal->shape(), "p0"); + auto pad_value = b.ConstantR0(kInitValue); + CHECK(param.reducer == kAdd || param.reducer == kMax); + auto computation = param.reducer == kAdd + ? CreateScalarAddComputation(F32, &b) + : CreateScalarMaxComputation(F32, &b); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/pad_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, + /*padding=*/padding); + + CHECK(param.reducer == kAdd || param.reducer == kMax); + auto reduce_func = param.reducer == kAdd + ? +[](float a, float b) { return a + b; } + : +[](float a, float b) { return std::max(a, b); }; + std::unique_ptr> expected = + ReferenceUtil::ReduceWindow4DGeneric( + /*operand=*/input, + /*init=*/kInitValue, + /*reduce_func=*/reduce_func, + /*window=*/param.window_bounds, + /*stride=*/param.strides, + /*padding=*/padding); + ComputeAndCompareR4(&b, *expected, {input_arg.get()}); + } +}; + +TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); } + +// base_bounds, window_bounds, strides, pad_low, pad_high +const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { + // Minimal edge case. 
+ R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // Zero base bound edge case. + R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With non-1x1 window. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With max instead of add. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kMax}, + + // With stride. + R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140}, + /*window_bounds=*/{3, 2, 1, 1}, + /*strides=*/{2, 4, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With low padding. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{3, 2, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{3, 2, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With high padding. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{3, 2, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{2, 3, 0, 0}, + /*reducer=*/kAdd}, + + // Window touches both sides of the padding simultaneously. + R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{1, 1, 0, 0}, + /*pad_high=*/{1, 1, 0, 0}, + /*reducer=*/kAdd}, + + // Window is entirely in the padding for some positions. 
+ R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{4, 4, 0, 0}, + /*pad_high=*/{4, 4, 0, 0}, + /*reducer=*/kAdd}, + + // Zero base bound with padding edge case. + R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 1, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With stride, low padding and high padding. + R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140}, + /*window_bounds=*/{3, 4, 1, 1}, + /*strides=*/{3, 1, 1, 1}, + /*pad_low=*/{10, 1, 0, 0}, + /*pad_high=*/{2, 3, 0, 0}, + /*reducer=*/kAdd}, + + // With second minor dimension == 9. + R4ReduceWindowTestData{/*base_bounds=*/{2, 3, 9, 127}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With minor dimension == 129. + R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129}, + /*window_bounds=*/{1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With minor dims reduction and non-overlapped stride. + R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16}, + /*window_bounds=*/{1, 1, 2, 2}, + /*strides=*/{1, 1, 2, 2}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, + + // With minor dims reduction and overlapped stride. 
+ R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16}, + /*window_bounds=*/{1, 1, 4, 4}, + /*strides=*/{1, 1, 2, 2}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*reducer=*/kAdd}, +}; + +INSTANTIATE_TEST_CASE_P(R4ReduceWindowTestInstantiation, R4ReduceWindowTest, + ::testing::ValuesIn(kR4ReduceWindowTestValues), + R4ReduceWindowTestDataToString); + +class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; + +XLA_TEST_P(R4ReduceWindowLargeTest, DoIt) { DoIt(); } + +// Test cases that are large/slow/failed. +const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { + R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{1, 1, 0, 0}, + /*pad_high=*/{1, 1, 0, 0}, + /*reducer=*/kMax}, + + R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{1, 1, 0, 0}, + /*reducer=*/kAdd}, +}; + +INSTANTIATE_TEST_CASE_P(R4ReduceWindowLargeTestInstantiation, + R4ReduceWindowLargeTest, + ::testing::ValuesIn(kR4ReduceWindowLargeTestValues), + R4ReduceWindowTestDataToString); + +struct R2ReduceWindowTestData { + int64 base_bounds[2]; + int64 window_bounds[2]; + int64 strides[2]; + int64 layout[2]; + Padding padding; + Reducer reducer; +} kR2TestCases[] = { + {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4}, + /*strides=*/{1, 2}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4}, + /*strides=*/{1, 1}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3}, + /*strides=*/{1, 1}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100}, + /*strides=*/{2, 99}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + 
{/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25}, + /*strides=*/{5, 4}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2}, + /*strides=*/{3, 3}, /*layout=*/{0, 1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36}, + /*strides=*/{4, 5}, /*layout=*/{1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + // Regression test for a bug that appeared in Inception (b/34784899). + {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + // Regression test for a bug that appeared in Inception (b/34784899). + {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2}, + /*strides=*/{2, 2}, /*layout=*/{1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, +}; + +string R2ReduceWindowTestDataToString( + const ::testing::TestParamInfo& data) { + string str = tensorflow::strings::StrCat( + "base_bounds_", + tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "__window_bounds_", + tensorflow::str_util::Join(data.param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // + "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // + "__layout_", data.param.layout[0], "_", data.param.layout[1], // + "__reducer_", data.param.reducer == kAdd ? 
"add" : "max"); + return str; +} + +class R2ReduceWindowTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(R2ReduceWindowTest, Add) { + ComputationBuilder b(client_, TestName()); + const auto& param = GetParam(); + CHECK(param.reducer == kAdd); + + const float kInitValue = 0.0f; + Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); + std::unique_ptr input_literal = + Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr input_arg, + client_->TransferToServer(*input_literal)); + b.ReduceWindow(/*operand=*/ + b.Parameter(0, input_literal->shape(), "p0"), + /*init_value=*/b.ConstantR0(kInitValue), + /*computation=*/CreateScalarAddComputation(F32, &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto expected = ReferenceUtil::ReduceWindow2DAdd( + /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareR2(&b, *expected, {input_arg.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +INSTANTIATE_TEST_CASE_P(R2ReduceWindowTestInstantiation, R2ReduceWindowTest, + ::testing::ValuesIn(kR2TestCases), + R2ReduceWindowTestDataToString); + +struct R1ReduceWindowTestData { + int64 base_bounds[1]; + int64 window_bounds[1]; + int64 strides[1]; + Padding padding; + Reducer reducer; +} kR1TestCases[] = { + {/*base_bounds=*/{1}, /*window_bounds=*/{1}, + /*strides=*/{1}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{3}, /*window_bounds=*/{3}, + /*strides=*/{1}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{3}, /*window_bounds=*/{2}, + /*strides=*/{1}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{1}, + /*strides=*/{1}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + + 
{/*base_bounds=*/{16}, /*window_bounds=*/{4}, + /*strides=*/{4}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + + {/*base_bounds=*/{16}, /*window_bounds=*/{4}, + /*strides=*/{3}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30}, + /*strides=*/{27}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7}, + /*strides=*/{64}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32}, + /*strides=*/{56}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{3}, /*window_bounds=*/{2}, + /*strides=*/{1}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{3}, + /*strides=*/{2}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{16}, /*window_bounds=*/{4}, + /*strides=*/{3}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, +}; + +string R1ReduceWindowTestDataToString( + const ::testing::TestParamInfo& data) { + string str = tensorflow::strings::StrCat( + "base_bounds_", + tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "__window_bounds_", + tensorflow::str_util::Join(data.param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // + "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // + "__reducer_", data.param.reducer == kAdd ? 
"add" : "max"); + return str; +} + +class R1ReduceWindowTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(R1ReduceWindowTest, DoIt) { + ComputationBuilder b(client_, TestName()); + const auto& param = GetParam(); + CHECK(param.reducer == kAdd || param.reducer == kMax); + + const float kInitValue = 0.0f; + std::vector input_vector(param.base_bounds[0]); + std::iota(std::begin(input_vector), std::end(input_vector), 0); + std::unique_ptr input_literal = + Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr input_arg, + client_->TransferToServer(*input_literal)); + + auto computation = param.reducer == kAdd + ? CreateScalarAddComputation(F32, &b) + : CreateScalarMaxComputation(F32, &b); + b.ReduceWindow(/*operand=*/ + b.Parameter(0, input_literal->shape(), "p0"), + /*init_value=*/b.ConstantR0(kInitValue), + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto reduce_func = param.reducer == kAdd + ? 
+[](float a, float b) { return a + b; } + : +[](float a, float b) { return std::max(a, b); }; + auto expected = ReferenceUtil::ReduceWindow1DGeneric( + /*operand=*/tensorflow::gtl::ArraySlice(input_vector), + /*init=*/kInitValue, + /*reduce_func=*/reduce_func, + /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareR1(&b, tensorflow::gtl::ArraySlice(*expected), + {input_arg.get()}, ErrorSpec(1e-3, 1e-3)); +} + +INSTANTIATE_TEST_CASE_P(R1ReduceWindowTestInstantiation, R1ReduceWindowTest, + ::testing::ValuesIn(kR1TestCases), + R1ReduceWindowTestDataToString); } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 7c6700feef846242cc49e573fee01c0101b05335..cb7f54ea01c2f063db1575bd498634f5107a39c5 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -61,7 +60,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { // Run it. 
std::unique_ptr literal = - client_->ExecuteAndTransfer(replayed, /*arguments=*/{}) + client_ + ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect 4. @@ -92,15 +92,16 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*LiteralUtil::CreateR0(2)) + client_->TransferToServer(*Literal::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*LiteralUtil::CreateR0(3)) + client_->TransferToServer(*Literal::CreateR0(3)) .ConsumeValueOrDie(); std::unique_ptr literal = client_ ->ExecuteAndTransfer(replayed, - /*arguments=*/{x_data.get(), y_data.get()}) + /*arguments=*/{x_data.get(), y_data.get()}, + &execution_options_) .ConsumeValueOrDie(); // Expect 5. @@ -141,7 +142,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { // Run it. std::unique_ptr literal = - client_->ExecuteAndTransfer(replayed, /*arguments=*/{}) + client_ + ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect result. @@ -154,7 +156,6 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index c9817bc23d821d95e660b359ce72ae6f4dec6c85..3051562455f48625def2840913314b16e8de2b72 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -63,7 +62,6 @@ TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index ae7d07727b1e2c20d629f2abc5e58036060f0cef..6748d196c1a6305cc6e3ff87191d2c96a45bf0e7 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -71,7 +70,7 @@ XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + std::unique_ptr param0_literal = Literal::CreateR0(1.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -99,7 +98,7 @@ XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); + Literal::CreateR2FromArray2D(Array2D(0, 3)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -403,7 +402,7 @@ XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { XLA_TEST_F(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { ComputationBuilder b(client_, TestName()); - auto input = LiteralUtil::CreateR1({83.0f}); + auto input = Literal::CreateR1({83.0f}); std::vector ones(rank, 1); // this is {1, ..., 1}. 
std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -435,7 +434,7 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); // clang-format off - auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(Array4D{ + auto literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { { {0, 1}, @@ -467,7 +466,7 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }); Computation computation = builder.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); std::unique_ptr actual = @@ -475,12 +474,12 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = - LiteralUtil::CreateR2FromArray2D(expected_array); + Literal::CreateR2FromArray2D(expected_array); LiteralTestUtil::ExpectEqual(*expected, *actual); } XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - std::unique_ptr input = LiteralUtil::CreateR2({ + std::unique_ptr input = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -508,7 +507,7 @@ XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { // Tests R2->R4 reshape with the reshape dimensions {1, 0}. 
XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - std::unique_ptr input = LiteralUtil::CreateR2({ + std::unique_ptr input = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -542,7 +541,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -565,7 +564,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -589,7 +588,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -603,7 +602,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); - auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + auto expected = Literal::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); } @@ -615,7 +614,7 @@ 
XLA_TEST_F(ReshapeTest, NoopReshape) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -626,7 +625,7 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { /*new_sizes=*/{7, 2, 3, 5}); Computation computation = builder.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); std::unique_ptr output_literal = @@ -642,7 +641,7 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { } XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { - auto literal_1x2x3x4 = LiteralUtil::CreateR4( + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -655,7 +654,7 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { } XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = LiteralUtil::CreateR4( + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -665,7 +664,7 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = LiteralUtil::CreateR4( + auto expected_2x4x3x1 = Literal::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -689,7 +688,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + 
Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -698,9 +697,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. @@ -718,7 +717,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -727,9 +726,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -747,7 +746,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -756,9 +755,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -777,7 +776,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -786,9 +785,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -806,7 +805,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -815,9 +814,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal), - input_literal->shape().layout()); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) + ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. @@ -831,7 +830,6 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 5ca9702380f4e37b6ba90459222faf832472bbf7..2f72fc0729a8634456986f294bd26de2c37a5212 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -159,7 +158,6 @@ TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 05ce22fc359d5c805840e0f07f645cfb8ffb7786..5b4c05c673339a455c9e58d81c73ede182e0f110 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/packed_literal_reader.h" @@ -66,8 +65,8 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0, LiteralUtil::Get(*actual, {0})); - EXPECT_EQ(24.0, LiteralUtil::Get(*actual, {1})); + EXPECT_EQ(42.0, actual->Get({0})); + EXPECT_EQ(24.0, actual->Get({1})); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { @@ -96,10 +95,10 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, LiteralUtil::Get(*actual, {0, 0})); - EXPECT_EQ(24.0f, LiteralUtil::Get(*actual, {0, 1})); - EXPECT_EQ(64.0f, LiteralUtil::Get(*actual, {1, 0})); - EXPECT_EQ(46.0f, LiteralUtil::Get(*actual, {1, 1})); + EXPECT_EQ(42.0f, actual->Get({0, 0})); + EXPECT_EQ(24.0f, actual->Get({0, 1})); + EXPECT_EQ(64.0f, actual->Get({1, 0})); + EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); LiteralTestUtil::ExpectEqual(*round_tripped, *actual); @@ -131,10 +130,10 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, LiteralUtil::Get(*actual, {0, 0})); - EXPECT_EQ(24.0f, LiteralUtil::Get(*actual, {1, 0})); - EXPECT_EQ(64.0f, LiteralUtil::Get(*actual, {0, 1})); - EXPECT_EQ(46.0f, LiteralUtil::Get(*actual, {1, 1})); + EXPECT_EQ(42.0f, actual->Get({0, 0})); + EXPECT_EQ(24.0f, actual->Get({1, 0})); + EXPECT_EQ(64.0f, actual->Get({0, 1})); + EXPECT_EQ(46.0f, 
actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); LiteralTestUtil::ExpectEqual(*round_tripped, *actual); @@ -146,7 +145,6 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index f0760241cdb4e555f3536d024d278c87376bb4d3..e6a6b7b37a4308f2c00f35ae8d3013a59f6c05e7 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -48,62 +47,61 @@ class RoundTripTransferTest : public ClientLibraryTestBase { }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0(42)); + RoundTripTest(*Literal::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0(42.0)); + RoundTripTest(*Literal::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1({})); + RoundTripTest(*Literal::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); + RoundTripTest(*Literal::CreateR1({42.0, 64.0})); } 
TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest(*Literal::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(*Literal::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4({{ + RoundTripTest(*Literal::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -111,36 +109,33 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(*Literal::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { - RoundTripTest( - 
*LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), + Literal::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({}).get(), + Literal::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), - LiteralUtil::CreateR1({2, 3}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR0(1.0).get(), + Literal::CreateR1({2, 3}).get()})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); + RoundTripTest(*Literal::CreateR4FromArray4D(array4d)); } } // namespace @@ -149,7 +144,6 @@ TEST_F(RoundTripTransferTest, R4F32_Large) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 47a39ffbbc42dedccc98694a23372cb064da752a..07bd00f015406784b6fc97ddf2a4b6f12ddd5864 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ 
b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -212,9 +211,9 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); - std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); - std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + std::unique_ptr a_literal = Literal::CreateR0(2.1f); + std::unique_ptr b_literal = Literal::CreateR0(5.5f); + std::unique_ptr c_literal = Literal::CreateR0(0.5f); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -361,8 +360,8 @@ TEST_F(ScalarComputationsTest, DivU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = LiteralUtil::CreateR0(dividend); - auto divisor_literal = LiteralUtil::CreateR0(divisor); + auto dividend_literal = Literal::CreateR0(dividend); + auto divisor_literal = Literal::CreateR0(divisor); TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, @@ -373,8 +372,7 @@ TEST_F(ScalarComputationsTest, DivU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = - LiteralUtil::CreateR0(dividend / divisor); + auto expected_literal = Literal::CreateR0(dividend / divisor); LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } } @@ -403,8 
+401,8 @@ TEST_F(ScalarComputationsTest, RemU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = LiteralUtil::CreateR0(dividend); - auto divisor_literal = LiteralUtil::CreateR0(divisor); + auto dividend_literal = Literal::CreateR0(dividend); + auto divisor_literal = Literal::CreateR0(divisor); TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, @@ -415,8 +413,7 @@ TEST_F(ScalarComputationsTest, RemU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = - LiteralUtil::CreateR0(dividend % divisor); + auto expected_literal = Literal::CreateR0(dividend % divisor); LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } } @@ -428,7 +425,7 @@ TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); - std::unique_ptr literal = LiteralUtil::CreateR0(87919); + std::unique_ptr literal = Literal::CreateR0(87919); TF_ASSIGN_OR_ASSERT_OK(auto input_data, client_->TransferToServer(*literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } @@ -764,7 +761,7 @@ TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { TEST_F(ScalarComputationsTest, SqrtF320) { ComputationBuilder builder(client_, TestName()); - Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); + Literal zero_literal = Literal::Zero(PrimitiveType::F32); std::unique_ptr zero_data = client_->TransferToServer(zero_literal).ConsumeValueOrDie(); @@ -782,7 +779,6 @@ TEST_F(ScalarComputationsTest, SqrtF320) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool 
parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 36110da2478083a45d5d378935278de42d55d221..de89588042ec097180906f49fb5b0c4b1fe16edd 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -381,7 +380,6 @@ XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 5eb4fee8ed28192a238efe2e6c9e1cad49a5f836..6b48116b6e1317eb23624242f1de656c3e7d48ca 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -262,7 +261,6 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc index 25bb915be56560e9a4eb0ebce990f488fe074241..38fc27f200ce823c2385d9456f8754dfccb1525e 100644 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ b/tensorflow/compiler/xla/tests/set_return_value_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -102,7 +101,6 @@ TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 70345c300cc778d9a52ffb857b8a1df2531e8d30..5e7d47566245fe72eb8b01c7abd85b29a305ea02 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -269,7 +268,6 @@ INSTANTIATE_TEST_CASE_P( int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 6a23df4d3c35a17a56b4ce816f79eaa642831f90..f3a522b05ebae4f1f86d6d7ddbac6e1749d3e286 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -61,7 +61,7 @@ std::unique_ptr CreateR2LiteralWithLayout( auto literal = MakeUnique(); const int64 d0 = values.size(); const int64 d1 = values.begin()->size(); - LiteralUtil::PopulateWithValue(0, {d0, d1}, literal.get()); + literal.get()->PopulateWithValue(0, {d0, d1}); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); @@ -70,7 +70,7 @@ std::unique_ptr CreateR2LiteralWithLayout( for (auto inner_list : values) { int64 dim1 = 0; for (auto value : inner_list) { - LiteralUtil::Set(literal.get(), {dim0, dim1}, value); + literal.get()->Set({dim0, dim1}, value); ++dim1; } ++dim0; @@ -88,7 +88,7 @@ std::unique_ptr CreateR3LiteralWithLayout( const int64 d0 = values.size(); const int64 d1 = values.begin()->size(); const int64 d2 = values.begin()->begin()->size(); - LiteralUtil::PopulateWithValue(0, {d0, 
d1, d2}, literal.get()); + literal.get()->PopulateWithValue(0, {d0, d1, d2}); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); @@ -99,7 +99,7 @@ std::unique_ptr CreateR3LiteralWithLayout( for (auto inner_inner_list : inner_list) { int64 dim2 = 0; for (auto value : inner_inner_list) { - LiteralUtil::Set(literal.get(), {dim0, dim1, dim2}, value); + literal.get()->Set({dim0, dim1, dim2}, value); ++dim2; } ++dim1; diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index e4951c4201060ae01d48f438bc462191de372f0e..07c0f073e86ee204a90b1f138c8c6d90a5c6936a 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -189,7 +188,6 @@ TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 6309e7129735aaaa81a14974ea52bd4cba219dc3..4a1c3fe9629218a0c3c8f5ccacd5500cedf73b61 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ 
b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -40,6 +39,25 @@ class TupleTest : public ClientLibraryTestBase { ErrorSpec error_spec_{0.0001}; }; +// Tests a tuple-shaped constant. +XLA_TEST_F(TupleTest, TupleConstant) { + ComputationBuilder builder(client_, TestName()); + + const float constant_scalar = 7.3f; + std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; + std::initializer_list> constant_matrix = { + {1.1f, 2.2f, 3.5f}, // row 0 + {4.8f, 5.0f, 6.7f}, // row 1 + }; + auto value = + Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), + Literal::CreateR1(constant_vector).get(), + Literal::CreateR2(constant_matrix).get()}); + + auto result = builder.ConstantLiteral(*value); + ComputeAndCompareTuple(&builder, *value, {}, error_spec_); +} + // Tests the creation of tuple data. 
XLA_TEST_F(TupleTest, TupleCreate) { ComputationBuilder builder(client_, TestName()); @@ -54,10 +72,10 @@ XLA_TEST_F(TupleTest, TupleCreate) { builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); + auto expected = + Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), + Literal::CreateR1(constant_vector).get(), + Literal::CreateR2(constant_matrix).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -68,9 +86,8 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { auto result = builder.Tuple( {builder.ConstantR0(7.0), builder.ConstantR1({})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), - LiteralUtil::CreateR1({}).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), + Literal::CreateR1({}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -78,7 +95,7 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XLA_TEST_F(TupleTest, EmptyTupleCreate) { ComputationBuilder builder(client_, TestName()); auto result = builder.Tuple({}); - auto expected = LiteralUtil::MakeTuple({}); + auto expected = Literal::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -147,12 +164,37 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { builder.ConstantR2(constant_matrix)}); auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1), builder.GetTupleElement(tuple_data, 0)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2(constant_matrix).get(), - LiteralUtil::CreateR1(constant_vector).get()}); + auto expected = + Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), + Literal::CreateR1(constant_vector).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } 
+XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle v1, v2; + + for (bool direction : {false, true}) { + std::unique_ptr v1_data = + CreateR0Parameter(0.0f, /*parameter_number=*/0, /*name=*/"v1", + /*builder=*/&b, /*data_handle=*/&v1); + std::unique_ptr v2_data = + CreateR0Parameter(1.0f, /*parameter_number=*/1, /*name=*/"v2", + /*builder=*/&b, /*data_handle=*/&v2); + auto v1_gt = b.Gt(v1, v2); // false + auto v2_gt = b.Gt(v2, v1); // true + auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} + auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} + auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + auto expected = + Literal::MakeTuple({Literal::CreateR0(direction).get(), + Literal::CreateR0(!direction).get()}); + + ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + error_spec_); + } +} + // Builds two new tuples from an existing tuple (by means of GetTupleElement), // then adds up the components of the new tuples. 
XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { @@ -213,9 +255,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), + Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -259,9 +300,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { auto select = builder.Select(builder.ConstantR0(true), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), - LiteralUtil::CreateR1(vec2).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), + Literal::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -340,9 +380,8 @@ XLA_TEST_F(TupleTest, auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), + Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -353,13 +392,13 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto outer_tuple = builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); - auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); - auto expected_s = LiteralUtil::CreateR0(42.0); + auto expected_v1 = Literal::CreateR1({1.0, 2.0}); + auto expected_s = Literal::CreateR0(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); - auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); + Literal::MakeTuple({expected_v1.get(), expected_s.get()}); + auto expected_v2 = Literal::CreateR1({22.0, 44.0}); auto expected = - 
LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -379,14 +418,14 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( + ->TransferToServer(*Literal::MakeTuple({ + Literal::MakeTuple( { - LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), + Literal::CreateR1({1.0, 2.0, 3.0}).get(), + Literal::CreateR1({4.0, 5.0, 6.0}).get(), }) .get(), - LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + Literal::CreateR1({7.0, 8.0, 9.0}).get(), })) .ConsumeValueOrDie(); @@ -401,7 +440,6 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 61110d5b4cdaea62aa9844a195ee95698bf1632e..d35d9ecdeb6661ff5d5c8940a0e9dcc609aeb9a2 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -165,7 +164,6 @@ TEST_F(UnaryOpTest, SignAbsTestR2) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 26a08953b1534044058a001a8c9a66e6ab6461b0..079dbb06117949c870f89e1a3258e31463aa28ec 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -221,7 +220,6 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index efde45375fdbe8c0abbba0817f9d3062a118ab3c..b2e0c796bde46bac357635a0ab35dc521da7fde4 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -441,7 +440,6 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 5f9177977449561acaec6f480937833ea0de3dd1..afa7d871c0e6c0922069e4846d3b7c28c4cb821f 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -82,6 +81,70 @@ TEST_F(WhileTest, WhileWithScalarResult) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { + auto result_shape = ShapeUtil::MakeShape(S32, {}); + auto orig_shape = ShapeUtil::MakeShape(S32, {2}); + + // Create a computation for the condition: repeat for 5 iterations. 
+ Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0(5), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: add 1 to the result variable. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto input = builder.ConstantR0(1); + auto result = builder.Add(input, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.Reduce(builder.ConstantR1(2, 1), + builder.ConstantR0(0), + CreateScalarAddComputation(S32, &builder), {0}); + auto result = builder.While(condition, body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + + ComputeAndCompareR0(&builder, 5, {}); +} + +TEST_F(WhileTest, WhileWithPredicateResult) { + auto result_shape = ShapeUtil::MakeShape(PRED, {}); + + // Create a computation for the condition: run until condition is true. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Ne(builder.ConstantR0(true), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: or condition with true. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto result = builder.LogicalOr(prev, builder.ConstantR0(true)); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. 
+ ComputationBuilder builder(client_, TestName()); + auto init = builder.Ne(builder.ConstantR0(false), + builder.ConstantR0(true)); + auto result = builder.While(condition, body, init); + + ComputeAndCompareR0(&builder, true, {}); +} + // Tests a while node when the result type T is a vector. // // All constants are chosen to produce exact results. @@ -240,15 +303,62 @@ TEST_F(WhileTest, WhileWithTupleResult) { VLOG(2) << "while = " << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = LiteralUtil::CreateR0(5); - auto expected_data = LiteralUtil::CreateR1( + auto expected_counter = Literal::CreateR0(5); + auto expected_data = Literal::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +TEST_F(WhileTest, WhileWithPredicateTupleResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(PRED, {})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. 
+ // Add 1 to the iteration variable and or the predicate with true + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto pred = builder.GetTupleElement(prev, 1); + auto new_pred = builder.LogicalOr(pred, builder.ConstantR0(true)); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple({builder.ConstantR0(0), + builder.Ne(builder.ConstantR0(false), + builder.ConstantR0(true))}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(5); + auto expected_predicate = Literal::CreateR0(true); + auto expected = + Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); +} + // Tests two while nodes when the result type T is a Tuple and the second // while node uses the result of the first while node which is used in two // nodes. 
@@ -525,11 +635,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = LiteralUtil::CreateR0(5); - auto expected_data = LiteralUtil::CreateR1( + auto expected_counter = Literal::CreateR0(5); + auto expected_data = Literal::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -589,7 +699,7 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { for (int i = 1; i < 4; ++i) { TF_ASSIGN_OR_ASSERT_OK(auto computation, while_loop(i)); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(65); TF_ASSIGN_OR_ASSERT_OK( auto result, @@ -743,7 +853,6 @@ BENCHMARK(BM_WhileLoop); int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 7876272467890b56c2cca71f64e66303eb8ac632..4d060895d357493327ec50b38016478c65fef94d 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -104,8 +104,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { auto result = MakeUnique(); const float fill = std::numeric_limits::quiet_NaN(); - LiteralUtil::PopulateWithValue(fill, AsInt64Slice(shape.dimensions()), - result.get()); + result->PopulateWithValue(fill, 
AsInt64Slice(shape.dimensions())); std::vector pieces; std::vector coordinates; std::vector coordinate_values; @@ -147,7 +146,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { "\"%s\"", shape.dimensions_size(), coordinate_values.size(), line.c_str()); } - LiteralUtil::Set(result.get(), coordinate_values, value); + result->Set(coordinate_values, value); } return std::move(result); } diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index a167d80f73b0273739e22d94be8d90ab00839dc9..23070b663870a2b78b38663e09a32fcb28d9c2dc 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -46,12 +46,12 @@ TEST(TextLiteralReaderTest, ReadsR3File) { TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); - EXPECT_EQ(42.5, LiteralUtil::Get(*literal, {0, 0, 0})); - EXPECT_EQ(43.5, LiteralUtil::Get(*literal, {0, 0, 1})); - EXPECT_EQ(44.5, LiteralUtil::Get(*literal, {0, 0, 2})); - EXPECT_EQ(45.5, LiteralUtil::Get(*literal, {0, 1, 0})); - EXPECT_EQ(46.5, LiteralUtil::Get(*literal, {0, 1, 1})); - EXPECT_EQ(47.5, LiteralUtil::Get(*literal, {0, 1, 2})); + EXPECT_EQ(42.5, literal->Get({0, 0, 0})); + EXPECT_EQ(43.5, literal->Get({0, 0, 1})); + EXPECT_EQ(44.5, literal->Get({0, 0, 2})); + EXPECT_EQ(45.5, literal->Get({0, 1, 0})); + EXPECT_EQ(46.5, literal->Get({0, 1, 1})); + EXPECT_EQ(47.5, literal->Get({0, 1, 2})); } } // namespace diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index a5097e41cb3cb3fe1c10e3c21c00c2242087deba..3fee467594d8423c707abf07a0622a738437830a 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -45,9 +45,9 @@ namespace xla { tensorflow::Status status; tensorflow::WritableFile* f_ptr = f.get(); - 
LiteralUtil::EachCellAsString( - literal, [f_ptr, &status](tensorflow::gtl::ArraySlice indices, - const string& value) { + literal.EachCellAsString( + [f_ptr, &status](tensorflow::gtl::ArraySlice indices, + const string& value) { if (!status.ok()) { return; } diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 177ae4ea036af660b7a2be1d4082b30ca8fb9fac..70cf2fb1b8a1b4f2ecfdaeaef3a00ddc974e2652 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -30,7 +30,7 @@ namespace xla { namespace { TEST(TextLiteralWriterTest, WritesFloatLiteral) { - auto literal = LiteralUtil::CreateR2({ + auto literal = Literal::CreateR2({ {3.14, 2.17}, {1.23, 4.56}, }); string path = diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 535e5b605b4f68671c9b6a8af4a12732f88e744e..4bbe0ba0ddd93b59557d3a4c6007ed9d2f8b7c11 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,7 +36,7 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:service_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", @@ -187,7 +187,7 @@ cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:session_proto", diff --git 
a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 10efa9f3e8d856493b2db23195188da6fba65244..7861c3a9b72e85cba8907c82a9d36d0fe39889c2 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -32,7 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/service_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -53,8 +53,12 @@ void RealMain(tensorflow::gtl::ArraySlice args) { TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_generate_hlo_graph(".*"); + debug_options.set_xla_hlo_graph_layout(true); ComputationStats stats = - client->GetComputationStats(computation).ConsumeValueOrDie(); + client->GetComputationStats(computation, debug_options) + .ConsumeValueOrDie(); fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); } } @@ -63,12 +67,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } // namespace xla int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } tensorflow::port::InitMain(argv[0], &argc, &argv); - 
xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); - flags->xla_generate_hlo_graph = ".*"; - flags->xla_hlo_graph_layout = true; - tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] xla::tools::RealMain(args); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 850267d3195785a96bf8d2c80fe64fdb8aae0a91..51f90b07c66f7d839f587350726333b9dbe6a9f0 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -52,8 +52,12 @@ void RealMain(tensorflow::gtl::ArraySlice args) { TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_generate_hlo_graph(".*"); + debug_options.set_xla_hlo_dump_as_graphdef(true); ComputationStats stats = - client->GetComputationStats(computation).ConsumeValueOrDie(); + client->GetComputationStats(computation, debug_options) + .ConsumeValueOrDie(); fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); } } @@ -62,14 +66,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } // namespace xla int main(int argc, char** argv) { - 
tensorflow::port::InitMain(argv[0], &argc, &argv); - - xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); - flags->xla_generate_hlo_graph = ".*"; + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } - xla::legacy_flags::HloGraphDumperFlags* dumper_flags = - xla::legacy_flags::GetHloGraphDumperFlags(); - dumper_flags->xla_hlo_dump_as_graphdef = true; + tensorflow::port::InitMain(argv[0], &argc, &argv); tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 3a75bf6495415e569aafce1eccc843cc95f9f7fa..6228ca34c0835a7476e45037c9bb6373ee1750dd 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -98,11 +98,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool use_fake_data) { std::unique_ptr result = result_status.ConsumeValueOrDie(); fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), ShapeUtil::HumanString(result->shape()).c_str(), - LiteralUtil::ToString(*result).c_str()); + result->ToString().c_str()); if (module.has_result()) { fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(module.result().shape()).c_str(), - LiteralUtil::ToString(Literal(module.result())).c_str()); + Literal(module.result()).ToString().c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index b6538f5de07743ef7320343d6b23119e919d114f..b50cb5e28eac14ed99af566939f8bd64e393ff64 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ 
-42,5 +42,5 @@ int main(int argc, char **argv) { &literal_proto)); xla::Literal literal(literal_proto); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str()); + fprintf(stderr, "%s\n", literal.ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 2d983b407c64ab5547d722abcc2c564a7963f730..bbe9902aa17a585c4bad5b732330305dfdd45302 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -40,7 +40,7 @@ int main(int argc, char **argv) { xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal->ShortDebugString(); - fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(*literal).c_str()); + fprintf(stderr, "%s\n", literal->ToString().c_str()); if (literal->shape().element_type() == xla::F32) { float min = *std::min_element(literal->f32s().begin(), literal->f32s().end()); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 42d5c1d15501fb912551a044414e6fa0c83283b8..31f0c3147eba0f369e563af17078effcc2e6b159 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -195,16 +195,24 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); // 2. permutation.size() == input.size(). template - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts deleted file mode 100644 index 939300f3878e6c09551c77062a94a92d3cc07000..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {PointMetadata} from './data'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -// tslint:disable-next-line -export let MetadataCardPolymer = PolymerElement({ - is: 'vz-projector-metadata-card', - properties: { - hasMetadata: {type: Boolean, value: false}, - metadata: {type: Array}, - label: String - } -}); - -export class MetadataCard extends MetadataCardPolymer { - hasMetadata: boolean; - metadata: Array<{key: string, value: string}>; - label: string; - - private labelOption: string; - private pointMetadata: PointMetadata; - - private expandLessButton: HTMLButtonElement; - private expandMoreButton: HTMLButtonElement; - - ready() { - this.expandLessButton = - this.querySelector('#expand-less') as HTMLButtonElement; - this.expandMoreButton = - this.querySelector('#expand-more') as HTMLButtonElement; - } - /** Handles a click on the expand more icon. */ - _expandMore() { - (this.$$('#metadata-container') as any).toggle(); - - this.expandMoreButton.style.display = 'none'; - this.expandLessButton.style.display = ''; - } - - /** Handles a click on the expand less icon. 
*/ - _expandLess() { - (this.$$('#metadata-container') as any).toggle(); - this.expandMoreButton.style.display = ''; - this.expandLessButton.style.display = 'none'; - } - - updateMetadata(pointMetadata?: PointMetadata) { - this.pointMetadata = pointMetadata; - this.hasMetadata = (pointMetadata != null); - - if (pointMetadata) { - let metadata = []; - for (let metadataKey in pointMetadata) { - if (!pointMetadata.hasOwnProperty(metadataKey)) { - continue; - } - metadata.push({key: metadataKey, value: pointMetadata[metadataKey]}); - } - - this.metadata = metadata; - this.label = '' + this.pointMetadata[this.labelOption]; - } - } - - setLabelOption(labelOption: string) { - this.labelOption = labelOption; - if (this.pointMetadata) { - this.label = '' + this.pointMetadata[this.labelOption]; - } - } -} - -document.registerElement(MetadataCard.prototype.is, MetadataCard); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html deleted file mode 100644 index b82f3f520b5e62bb381f1a9c6ebd10c4a04d13cf..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html +++ /dev/null @@ -1,316 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts deleted file mode 100644 index 377c6c11ad5d19343682540bdadc3319b5d0ee3c..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts +++ /dev/null @@ -1,589 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import * as data from './data'; -import {DataSet, Projection, ProjectionType, SpriteAndMetadataInfo, State} from './data'; -import * as util from './util'; -import * as vector from './vector'; -import {Vector} from './vector'; -import {Projector} from './vz-projector'; -import {ProjectorInput} from './vz-projector-input'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -const NUM_PCA_COMPONENTS = 10; - -// tslint:disable-next-line -export let ProjectionsPanelPolymer = PolymerElement({ - is: 'vz-projector-projections-panel', - properties: { - pcaIs3d: - {type: Boolean, value: true, observer: '_pcaDimensionToggleObserver'}, - tSNEis3d: - {type: Boolean, value: true, observer: '_tsneDimensionToggleObserver'}, - // PCA projection. - pcaComponents: Array, - pcaX: {type: Number, value: 0, observer: 'showPCAIfEnabled'}, - pcaY: {type: Number, value: 1, observer: 'showPCAIfEnabled'}, - pcaZ: {type: Number, value: 2, observer: 'showPCAIfEnabled'}, - // Custom projection. - customSelectedSearchByMetadataOption: { - type: String, - observer: '_customSelectedSearchByMetadataOptionChanged' - }, - } -}); - -type InputControlName = 'xLeft'|'xRight'|'yUp'|'yDown'; - -type CentroidResult = { - centroid?: Vector; numMatches?: number; -}; - -type Centroids = { - [key: string]: Vector; xLeft: Vector; xRight: Vector; yUp: Vector; - yDown: Vector; -}; - -/** - * A polymer component which handles the projection tabs in the projector. 
- */ -export class ProjectionsPanel extends ProjectionsPanelPolymer { - private projector: Projector; - private pcaComponents: - Array<{id: number, componentNumber: number, percVariance: string}>; - private currentProjection: ProjectionType; - private polymerChangesTriggerReprojection: boolean; - private dataSet: DataSet; - private originalDataSet: DataSet; - private dim: number; - - /** T-SNE perplexity. Roughly how many neighbors each point influences. */ - private perplexity: number; - /** T-SNE learning rate. */ - private learningRate: number; - - private searchByMetadataOptions: string[]; - - /** Centroids for custom projections. */ - private centroidValues: any; - private centroids: Centroids; - /** The centroid across all points. */ - private allCentroid: number[]; - - /** Polymer properties. */ - // TODO(nsthorat): Move these to a separate view controller. - public tSNEis3d: boolean; - public pcaIs3d: boolean; - public pcaX: number; - public pcaY: number; - public pcaZ: number; - public customSelectedSearchByMetadataOption: string; - - /** Polymer elements. */ - private runTsneButton: HTMLButtonElement; - private stopTsneButton: HTMLButtonElement; - private perplexitySlider: HTMLInputElement; - private learningRateInput: HTMLInputElement; - private zDropdown: HTMLElement; - private iterationLabel: HTMLElement; - - private customProjectionXLeftInput: ProjectorInput; - private customProjectionXRightInput: ProjectorInput; - private customProjectionYUpInput: ProjectorInput; - private customProjectionYDownInput: ProjectorInput; - - initialize(projector: Projector) { - this.polymerChangesTriggerReprojection = true; - this.projector = projector; - - // Set up TSNE projections. - this.perplexity = 30; - this.learningRate = 10; - - // Setup Custom projections. 
- this.centroidValues = {xLeft: null, xRight: null, yUp: null, yDown: null}; - this.clearCentroids(); - - this.setupUIControls(); - } - - ready() { - this.zDropdown = this.querySelector('#z-dropdown') as HTMLElement; - this.runTsneButton = this.querySelector('.run-tsne') as HTMLButtonElement; - this.stopTsneButton = this.querySelector('.stop-tsne') as HTMLButtonElement; - this.perplexitySlider = - this.querySelector('#perplexity-slider') as HTMLInputElement; - this.learningRateInput = - this.querySelector('#learning-rate-slider') as HTMLInputElement; - this.iterationLabel = this.querySelector('.run-tsne-iter') as HTMLElement; - } - - disablePolymerChangesTriggerReprojection() { - this.polymerChangesTriggerReprojection = false; - } - - enablePolymerChangesTriggerReprojection() { - this.polymerChangesTriggerReprojection = true; - } - - private updateTSNEPerplexityFromSliderChange() { - if (this.perplexitySlider) { - this.perplexity = +this.perplexitySlider.value; - } - (this.querySelector('.tsne-perplexity span') as HTMLSpanElement).innerText = - '' + this.perplexity; - } - - private updateTSNELearningRateFromUIChange() { - if (this.learningRateInput) { - this.learningRate = Math.pow(10, +this.learningRateInput.value); - } - (this.querySelector('.tsne-learning-rate span') as HTMLSpanElement) - .innerText = '' + this.learningRate; - } - - private setupUIControls() { - { - const self = this; - const inkTabs = this.querySelectorAll('.ink-tab'); - for (let i = 0; i < inkTabs.length; i++) { - inkTabs[i].addEventListener('click', function() { - let id = this.getAttribute('data-tab'); - self.showTab(id); - }); - } - } - - this.runTsneButton.addEventListener('click', () => this.runTSNE()); - this.stopTsneButton.addEventListener( - 'click', () => this.dataSet.stopTSNE()); - - this.perplexitySlider.value = this.perplexity.toString(); - this.perplexitySlider.addEventListener( - 'change', () => this.updateTSNEPerplexityFromSliderChange()); - 
this.updateTSNEPerplexityFromSliderChange(); - - this.learningRateInput.addEventListener( - 'change', () => this.updateTSNELearningRateFromUIChange()); - this.updateTSNELearningRateFromUIChange(); - - this.setupCustomProjectionInputFields(); - // TODO: figure out why `--paper-input-container-input` css mixin didn't - // work. - const inputs = - this.querySelectorAll('paper-dropdown-menu paper-input input'); - for (let i = 0; i < inputs.length; i++) { - (inputs[i] as HTMLElement).style.fontSize = '14px'; - } - } - - restoreUIFromBookmark(bookmark: State) { - this.disablePolymerChangesTriggerReprojection(); - - // PCA - this.pcaX = bookmark.pcaComponentDimensions[0]; - this.pcaY = bookmark.pcaComponentDimensions[1]; - if (bookmark.pcaComponentDimensions.length === 3) { - this.pcaZ = bookmark.pcaComponentDimensions[2]; - } - this.pcaIs3d = (bookmark.pcaComponentDimensions.length === 3); - - // t-SNE - if (this.perplexitySlider) { - this.perplexitySlider.value = bookmark.tSNEPerplexity.toString(); - } - if (this.learningRateInput) { - this.learningRateInput.value = bookmark.tSNELearningRate.toString(); - } - this.tSNEis3d = bookmark.tSNEis3d; - - // custom - this.customSelectedSearchByMetadataOption = - bookmark.customSelectedSearchByMetadataOption; - if (this.customProjectionXLeftInput) { - this.customProjectionXLeftInput.set( - bookmark.customXLeftText, bookmark.customXLeftRegex); - } - if (this.customProjectionXRightInput) { - this.customProjectionXRightInput.set( - bookmark.customXRightText, bookmark.customXRightRegex); - } - if (this.customProjectionYUpInput) { - this.customProjectionYUpInput.set( - bookmark.customYUpText, bookmark.customYUpRegex); - } - if (this.customProjectionYDownInput) { - this.customProjectionYDownInput.set( - bookmark.customYDownText, bookmark.customYDownRegex); - } - this.computeAllCentroids(); - - this.setZDropdownEnabled(this.pcaIs3d); - this.updateTSNEPerplexityFromSliderChange(); - this.updateTSNELearningRateFromUIChange(); - if 
(this.iterationLabel) { - this.iterationLabel.innerText = bookmark.tSNEIteration.toString(); - } - if (bookmark.selectedProjection != null) { - this.showTab(bookmark.selectedProjection); - } - this.enablePolymerChangesTriggerReprojection(); - } - - populateBookmarkFromUI(bookmark: State) { - this.disablePolymerChangesTriggerReprojection(); - - // PCA - bookmark.pcaComponentDimensions = [this.pcaX, this.pcaY]; - if (this.pcaIs3d) { - bookmark.pcaComponentDimensions.push(this.pcaZ); - } - - // t-SNE - if (this.perplexitySlider != null) { - bookmark.tSNEPerplexity = +this.perplexitySlider.value; - } - if (this.learningRateInput != null) { - bookmark.tSNELearningRate = +this.learningRateInput.value; - } - bookmark.tSNEis3d = this.tSNEis3d; - - // custom - bookmark.customSelectedSearchByMetadataOption = - this.customSelectedSearchByMetadataOption; - if (this.customProjectionXLeftInput != null) { - bookmark.customXLeftText = this.customProjectionXLeftInput.getValue(); - bookmark.customXLeftRegex = - this.customProjectionXLeftInput.getInRegexMode(); - } - if (this.customProjectionXRightInput != null) { - bookmark.customXRightText = this.customProjectionXRightInput.getValue(); - bookmark.customXRightRegex = - this.customProjectionXRightInput.getInRegexMode(); - } - if (this.customProjectionYUpInput != null) { - bookmark.customYUpText = this.customProjectionYUpInput.getValue(); - bookmark.customYUpRegex = this.customProjectionYUpInput.getInRegexMode(); - } - if (this.customProjectionYDownInput != null) { - bookmark.customYDownText = this.customProjectionYDownInput.getValue(); - bookmark.customYDownRegex = - this.customProjectionYDownInput.getInRegexMode(); - } - - this.enablePolymerChangesTriggerReprojection(); - } - - // This method is marked as public as it is used as the view method that - // abstracts DOM manipulation so we can stub it in a test. - // TODO(nsthorat): Move this to its own class as the glue between this class - // and the DOM. 
- setZDropdownEnabled(enabled: boolean) { - if (this.zDropdown) { - if (enabled) { - this.zDropdown.removeAttribute('disabled'); - } else { - this.zDropdown.setAttribute('disabled', 'true'); - } - } - } - - dataSetUpdated(dataSet: DataSet, originalDataSet: DataSet, dim: number) { - this.dataSet = dataSet; - this.originalDataSet = originalDataSet; - this.dim = dim; - const pointCount = (dataSet == null) ? 0 : dataSet.points.length; - const perplexity = Math.max(5, Math.ceil(Math.sqrt(pointCount) / 4)); - this.perplexitySlider.value = perplexity.toString(); - this.updateTSNEPerplexityFromSliderChange(); - this.clearCentroids(); - - (this.querySelector('#tsne-sampling') as HTMLElement).style.display = - pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none'; - const wasSampled = - (dataSet == null) ? false : (dataSet.dim[0] > data.PCA_SAMPLE_DIM || - dataSet.dim[1] > data.PCA_SAMPLE_DIM); - (this.querySelector('#pca-sampling') as HTMLElement).style.display = - wasSampled ? null : 'none'; - this.showTab('pca'); - } - - _pcaDimensionToggleObserver() { - this.setZDropdownEnabled(this.pcaIs3d); - this.beginProjection(this.currentProjection); - } - - _tsneDimensionToggleObserver() { - this.beginProjection(this.currentProjection); - } - - metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) { - // Project by options for custom projections. - let searchByMetadataIndex = -1; - this.searchByMetadataOptions = spriteAndMetadata.stats.map((stats, i) => { - // Make the default label by the first non-numeric column. 
- if (!stats.isNumeric && searchByMetadataIndex === -1) { - searchByMetadataIndex = i; - } - return stats.name; - }); - this.customSelectedSearchByMetadataOption = - this.searchByMetadataOptions[Math.max(0, searchByMetadataIndex)]; - } - - public showTab(id: ProjectionType) { - this.currentProjection = id; - - const tab = - this.querySelector('.ink-tab[data-tab="' + id + '"]') as HTMLElement; - const allTabs = this.querySelectorAll('.ink-tab'); - for (let i = 0; i < allTabs.length; i++) { - util.classed(allTabs[i] as HTMLElement, 'active', false); - } - - util.classed(tab, 'active', true); - - const allTabContent = this.querySelectorAll('.ink-panel-content'); - for (let i = 0; i < allTabContent.length; i++) { - util.classed(allTabContent[i] as HTMLElement, 'active', false); - } - - util.classed( - this.querySelector('.ink-panel-content[data-panel="' + id + '"]') as - HTMLElement, - 'active', true); - - // guard for unit tests, where polymer isn't attached and $ doesn't exist. - if (this.$ != null) { - const main = this.$['main']; - // In order for the projections panel to animate its height, we need to - // set it explicitly. - requestAnimationFrame(() => { - this.style.height = main.clientHeight + 'px'; - }); - } - - this.beginProjection(id); - } - - private beginProjection(projection: ProjectionType) { - if (this.polymerChangesTriggerReprojection === false) { - return; - } - if (projection === 'pca') { - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - this.showPCA(); - } else if (projection === 'tsne') { - this.showTSNE(); - } else if (projection === 'custom') { - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - this.computeAllCentroids(); - this.reprojectCustom(); - } - } - - private showTSNE() { - const dataSet = this.dataSet; - if (dataSet == null) { - return; - } - const accessors = - data.getProjectionComponents('tsne', [0, 1, this.tSNEis3d ? 2 : null]); - const dimensionality = this.tSNEis3d ? 
3 : 2; - const projection = - new Projection('tsne', accessors, dimensionality, dataSet); - this.projector.setProjection(projection); - - if (!this.dataSet.hasTSNERun) { - this.runTSNE(); - } else { - this.projector.notifyProjectionPositionsUpdated(); - } - } - - private runTSNE() { - this.runTsneButton.disabled = true; - this.stopTsneButton.disabled = null; - this.dataSet.projectTSNE( - this.perplexity, this.learningRate, this.tSNEis3d ? 3 : 2, - (iteration: number) => { - if (iteration != null) { - this.iterationLabel.innerText = '' + iteration; - this.projector.notifyProjectionPositionsUpdated(); - } else { - this.runTsneButton.disabled = null; - this.stopTsneButton.disabled = true; - } - }); - } - - // tslint:disable-next-line:no-unused-variable - private showPCAIfEnabled() { - if (this.polymerChangesTriggerReprojection) { - this.showPCA(); - } - } - - private updateTotalVarianceMessage() { - let variances = this.dataSet.fracVariancesExplained; - let totalVariance = variances[this.pcaX] + variances[this.pcaY]; - let msg = 'Total variance described: '; - if (this.pcaIs3d) { - totalVariance += variances[this.pcaZ]; - } - msg += (totalVariance * 100).toFixed(1) + '%.'; - (this.querySelector('#total-variance') as HTMLElement).innerHTML = msg; - } - - private showPCA() { - if (this.dataSet == null) { - return; - } - this.dataSet.projectPCA().then(() => { - // Polymer properties are 1-based. - const accessors = data.getProjectionComponents( - 'pca', [this.pcaX, this.pcaY, this.pcaZ]); - - const dimensionality = this.pcaIs3d ? 
3 : 2; - const projection = - new Projection('pca', accessors, dimensionality, this.dataSet); - this.projector.setProjection(projection); - let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]); - this.updateTotalVarianceMessage(); - this.pcaComponents = util.range(numComponents).map(i => { - let fracVariance = this.dataSet.fracVariancesExplained[i]; - return { - id: i, - componentNumber: i + 1, - percVariance: (fracVariance * 100).toFixed(1) - }; - }); - }); - } - - private reprojectCustom() { - if (this.centroids == null || this.centroids.xLeft == null || - this.centroids.xRight == null || this.centroids.yUp == null || - this.centroids.yDown == null) { - return; - } - const xDir = vector.sub(this.centroids.xRight, this.centroids.xLeft); - this.dataSet.projectLinear(xDir, 'linear-x'); - - const yDir = vector.sub(this.centroids.yUp, this.centroids.yDown); - this.dataSet.projectLinear(yDir, 'linear-y'); - - const accessors = data.getProjectionComponents('custom', ['x', 'y']); - const projection = new Projection('custom', accessors, 2, this.dataSet); - this.projector.setProjection(projection); - } - - clearCentroids(): void { - this.centroids = {xLeft: null, xRight: null, yUp: null, yDown: null}; - this.allCentroid = null; - } - - _customSelectedSearchByMetadataOptionChanged(newVal: string, oldVal: string) { - if (this.polymerChangesTriggerReprojection === false) { - return; - } - if (this.currentProjection === 'custom') { - this.computeAllCentroids(); - this.reprojectCustom(); - } - } - - private setupCustomProjectionInputFields() { - this.customProjectionXLeftInput = - this.setupCustomProjectionInputField('xLeft'); - this.customProjectionXRightInput = - this.setupCustomProjectionInputField('xRight'); - this.customProjectionYUpInput = this.setupCustomProjectionInputField('yUp'); - this.customProjectionYDownInput = - this.setupCustomProjectionInputField('yDown'); - } - - private computeAllCentroids() { - this.computeCentroid('xLeft'); - 
this.computeCentroid('xRight'); - this.computeCentroid('yUp'); - this.computeCentroid('yDown'); - } - - private computeCentroid(name: InputControlName) { - const input = this.querySelector('#' + name) as ProjectorInput; - if (input == null) { - return; - } - const value = input.getValue(); - if (value == null) { - return; - } - let inRegexMode = input.getInRegexMode(); - let result = this.getCentroid(value, inRegexMode); - if (result.numMatches === 0) { - input.message = '0 matches. Using a random vector.'; - result.centroid = vector.rn(this.dim); - } else { - input.message = `${result.numMatches} matches.`; - } - this.centroids[name] = result.centroid; - this.centroidValues[name] = value; - } - - private setupCustomProjectionInputField(name: InputControlName): - ProjectorInput { - let input = this.querySelector('#' + name) as ProjectorInput; - input.registerInputChangedListener((input, inRegexMode) => { - if (this.polymerChangesTriggerReprojection) { - this.computeCentroid(name); - this.reprojectCustom(); - } - }); - return input; - } - - private getCentroid(pattern: string, inRegexMode: boolean): CentroidResult { - if (pattern == null || pattern === '') { - return {numMatches: 0}; - } - // Search by the original dataset since we often want to filter and project - // only the nearest neighbors of A onto B-C where B and C are not nearest - // neighbors of A. 
- let accessor = (i: number) => this.originalDataSet.points[i].vector; - let r = this.originalDataSet.query( - pattern, inRegexMode, this.customSelectedSearchByMetadataOption); - return {centroid: vector.centroid(r, accessor), numMatches: r.length}; - } - - getPcaSampledDimText() { - return data.PCA_SAMPLE_DIM.toLocaleString(); - } - - getPcaSampleSizeText() { - return data.PCA_SAMPLE_SIZE.toLocaleString(); - } - - getTsneSampleSizeText() { - return data.TSNE_SAMPLE_SIZE.toLocaleString(); - } -} - -document.registerElement(ProjectionsPanel.prototype.is, ProjectionsPanel); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts deleted file mode 100644 index 44062062a364b742e2de6467614e508d4e89d37a..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -export type Spec = { - is: string; properties?: { - [key: string]: - (Function | - { - type: Function, value?: any; - readonly?: boolean; - notify?: boolean; - observer?: string; - }) - }; - observers?: string[]; -}; - -export function PolymerElement(spec: Spec) { - return Polymer.Class(spec as any) as{new (): PolymerHTMLElement}; -} - -export interface PolymerHTMLElement extends HTMLElement, polymer.Base {} diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.html b/tensorflow/tensorboard/components/vz_projector/vz-projector.html deleted file mode 100644 index 438ea9f4e978fa608eb0cabde35e9adf6f7e87fe..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.html +++ /dev/null @@ -1,346 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts deleted file mode 100644 index bf98a4d478599f7b859e893e7a17567f22fd5114..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts +++ /dev/null @@ -1,570 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {AnalyticsLogger} from './analyticsLogger'; -import * as data from './data'; -import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data'; -import {DataProvider, EmbeddingInfo, ServingMode} from './data-provider'; -import {DemoDataProvider} from './data-provider-demo'; -import {ProtoDataProvider} from './data-provider-proto'; -import {ServerDataProvider} from './data-provider-server'; -import * as knn from './knn'; -import * as logging from './logging'; -import {DistanceMetricChangedListener, HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext'; -import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter'; -import {MouseMode} from './scatterPlot'; -import * as util from './util'; -import {BookmarkPanel} from './vz-projector-bookmark-panel'; -import {DataPanel} from './vz-projector-data-panel'; -import {InspectorPanel} from './vz-projector-inspector-panel'; -import {MetadataCard} from './vz-projector-metadata-card'; -import {ProjectionsPanel} from './vz-projector-projections-panel'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -/** - * The minimum number of dimensions the data should have to automatically - * decide to normalize the data. 
- */ -const THRESHOLD_DIM_NORMALIZE = 50; -const POINT_COLOR_MISSING = 'black'; - -export let ProjectorPolymer = PolymerElement({ - is: 'vz-projector', - properties: { - routePrefix: String, - dataProto: {type: String, observer: '_dataProtoChanged'}, - servingMode: String, - projectorConfigJsonPath: String, - pageViewLogging: Boolean, - eventLogging: Boolean - } -}); - -const INDEX_METADATA_FIELD = '__index__'; - -export class Projector extends ProjectorPolymer implements - ProjectorEventContext { - // The working subset of the data source's original data set. - dataSet: DataSet; - servingMode: ServingMode; - // The path to the projector config JSON file for demo mode. - projectorConfigJsonPath: string; - - private selectionChangedListeners: SelectionChangedListener[]; - private hoverListeners: HoverListener[]; - private projectionChangedListeners: ProjectionChangedListener[]; - private distanceMetricChangedListeners: DistanceMetricChangedListener[]; - - private originalDataSet: DataSet; - private dataSetBeforeFilter: DataSet; - private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter; - private dim: number; - - private dataSetFilterIndices: number[]; - private selectedPointIndices: number[]; - private neighborsOfFirstPoint: knn.NearestEntry[]; - private hoverPointIndex: number; - - private dataProvider: DataProvider; - private inspectorPanel: InspectorPanel; - - private selectedColorOption: ColorOption; - private selectedLabelOption: string; - private routePrefix: string; - private normalizeData: boolean; - private projection: Projection; - - /** Polymer component panels */ - private dataPanel: DataPanel; - private bookmarkPanel: BookmarkPanel; - private projectionsPanel: ProjectionsPanel; - private metadataCard: MetadataCard; - - private statusBar: HTMLDivElement; - private analyticsLogger: AnalyticsLogger; - private eventLogging: boolean; - private pageViewLogging: boolean; - - ready() { - logging.setDomContainer(this); - - this.analyticsLogger = - new 
AnalyticsLogger(this.pageViewLogging, this.eventLogging); - this.analyticsLogger.logPageView('embeddings'); - - if (!util.hasWebGLSupport()) { - this.analyticsLogger.logWebGLDisabled(); - logging.setErrorMessage( - 'Your browser or device does not have WebGL enabled. Please enable ' + - 'hardware acceleration, or use a browser that supports WebGL.'); - return; - } - - this.selectionChangedListeners = []; - this.hoverListeners = []; - this.projectionChangedListeners = []; - this.distanceMetricChangedListeners = []; - this.selectedPointIndices = []; - this.neighborsOfFirstPoint = []; - - this.dataPanel = this.$['data-panel'] as DataPanel; - this.inspectorPanel = this.$['inspector-panel'] as InspectorPanel; - this.inspectorPanel.initialize(this, this as ProjectorEventContext); - this.projectionsPanel = this.$['projections-panel'] as ProjectionsPanel; - this.projectionsPanel.initialize(this); - this.bookmarkPanel = this.$['bookmark-panel'] as BookmarkPanel; - this.bookmarkPanel.initialize(this, this as ProjectorEventContext); - this.metadataCard = this.$['metadata-card'] as MetadataCard; - this.statusBar = this.querySelector('#status-bar') as HTMLDivElement; - this.scopeSubtree(this.$$('#notification-dialog'), true); - this.setupUIControls(); - this.initializeDataProvider(); - } - - setSelectedLabelOption(labelOption: string) { - this.selectedLabelOption = labelOption; - this.metadataCard.setLabelOption(this.selectedLabelOption); - this.projectorScatterPlotAdapter.setLabelPointAccessor(labelOption); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.render(); - } - - setSelectedColorOption(colorOption: ColorOption) { - this.selectedColorOption = colorOption; - this.projectorScatterPlotAdapter.setLegendPointColorer( - this.getLegendPointColorer(colorOption)); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.render(); - } - - setNormalizeData(normalizeData: boolean) { - 
this.normalizeData = normalizeData; - this.setCurrentDataSet(this.originalDataSet.getSubset()); - } - - updateDataSet( - ds: DataSet, spriteAndMetadata?: SpriteAndMetadataInfo, - metadataFile?: string) { - this.dataSetFilterIndices = null; - this.originalDataSet = ds; - if (ds != null) { - this.normalizeData = - this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE; - spriteAndMetadata = spriteAndMetadata || {}; - if (spriteAndMetadata.pointsInfo == null) { - let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points); - spriteAndMetadata.pointsInfo = pointsInfo; - spriteAndMetadata.stats = stats; - } - let metadataMergeSucceeded = ds.mergeMetadata(spriteAndMetadata); - if (!metadataMergeSucceeded) { - return; - } - } - if (this.projectorScatterPlotAdapter != null) { - if (ds == null) { - this.projectorScatterPlotAdapter.setLabelPointAccessor(null); - this.setProjection(null); - } else { - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.resize(); - this.projectorScatterPlotAdapter.render(); - } - } - if (ds != null) { - this.dataPanel.setNormalizeData(this.normalizeData); - this.setCurrentDataSet(ds.getSubset()); - this.projectorScatterPlotAdapter.setLabelPointAccessor( - this.selectedLabelOption); - this.inspectorPanel.datasetChanged(); - - this.inspectorPanel.metadataChanged(spriteAndMetadata); - this.projectionsPanel.metadataChanged(spriteAndMetadata); - this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile); - // Set the container to a fixed height, otherwise in Colab the - // height can grow indefinitely. 
- const container = this.querySelector('#container') as HTMLDivElement; - container.style.height = container.clientHeight + 'px'; - } else { - this.setCurrentDataSet(null); - } - } - - setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) { - this.bookmarkPanel.setSelectedTensor(run, tensorInfo, this.dataProvider); - } - - /** - * Registers a listener to be called any time the selected point set changes. - */ - registerSelectionChangedListener(listener: SelectionChangedListener) { - this.selectionChangedListeners.push(listener); - } - - filterDataset(pointIndices: number[]) { - const selectionSize = this.selectedPointIndices.length; - if (this.dataSetBeforeFilter == null) { - this.dataSetBeforeFilter = this.dataSet; - } - this.setCurrentDataSet(this.dataSet.getSubset(pointIndices)); - this.dataSetFilterIndices = pointIndices; - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.adjustSelectionAndHover(util.range(selectionSize)); - } - - resetFilterDataset() { - const originalPointIndices = this.selectedPointIndices.map( - filteredIndex => this.dataSet.points[filteredIndex].index); - this.setCurrentDataSet(this.dataSetBeforeFilter); - if (this.projection != null) { - this.projection.dataSet = this.dataSetBeforeFilter; - } - this.dataSetBeforeFilter = null; - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.dataSetFilterIndices = []; - this.adjustSelectionAndHover(originalPointIndices); - } - - /** - * Used by clients to indicate that a selection has occurred. 
- */ - notifySelectionChanged(newSelectedPointIndices: number[]) { - this.selectedPointIndices = newSelectedPointIndices; - let neighbors: knn.NearestEntry[] = []; - - if (newSelectedPointIndices.length === 1) { - neighbors = this.dataSet.findNeighbors( - newSelectedPointIndices[0], this.inspectorPanel.distFunc, - this.inspectorPanel.numNN); - this.metadataCard.updateMetadata( - this.dataSet.points[newSelectedPointIndices[0]].metadata); - } else { - this.metadataCard.updateMetadata(null); - } - - this.selectionChangedListeners.forEach( - l => l(this.selectedPointIndices, neighbors)); - } - - /** - * Registers a listener to be called any time the mouse hovers over a point. - */ - registerHoverListener(listener: HoverListener) { - this.hoverListeners.push(listener); - } - - /** - * Used by clients to indicate that a hover is occurring. - */ - notifyHoverOverPoint(pointIndex: number) { - this.hoverListeners.forEach(l => l(pointIndex)); - } - - registerProjectionChangedListener(listener: ProjectionChangedListener) { - this.projectionChangedListeners.push(listener); - } - - notifyProjectionChanged(projection: Projection) { - this.projectionChangedListeners.forEach(l => l(projection)); - } - - registerDistanceMetricChangedListener(l: DistanceMetricChangedListener) { - this.distanceMetricChangedListeners.push(l); - } - - notifyDistanceMetricChanged(distMetric: DistanceFunction) { - this.distanceMetricChangedListeners.forEach(l => l(distMetric)); - } - - _dataProtoChanged(dataProtoString: string) { - let dataProto = - dataProtoString ? 
JSON.parse(dataProtoString) as DataProto : null; - this.initializeDataProvider(dataProto); - } - - private makeDefaultPointsInfoAndStats(points: DataPoint[]): - [PointMetadata[], ColumnStats[]] { - let pointsInfo: PointMetadata[] = []; - points.forEach(p => { - let pointInfo: PointMetadata = {}; - pointInfo[INDEX_METADATA_FIELD] = p.index; - pointsInfo.push(pointInfo); - }); - let stats: ColumnStats[] = [{ - name: INDEX_METADATA_FIELD, - isNumeric: false, - tooManyUniqueValues: true, - min: 0, - max: pointsInfo.length - 1 - }]; - return [pointsInfo, stats]; - } - - private initializeDataProvider(dataProto?: DataProto) { - if (this.servingMode === 'demo') { - let projectorConfigUrl: string; - - // Only in demo mode do we allow the config being passed via URL. - let urlParams = util.getURLParams(window.location.search); - if ('config' in urlParams) { - projectorConfigUrl = urlParams['config']; - } else { - projectorConfigUrl = this.projectorConfigJsonPath; - } - this.dataProvider = new DemoDataProvider(projectorConfigUrl); - } else if (this.servingMode === 'server') { - if (!this.routePrefix) { - throw 'route-prefix is a required parameter'; - } - this.dataProvider = new ServerDataProvider(this.routePrefix); - } else if (this.servingMode === 'proto' && dataProto != null) { - this.dataProvider = new ProtoDataProvider(dataProto); - } - - this.dataPanel.initialize(this, this.dataProvider); - } - - private getLegendPointColorer(colorOption: ColorOption): - (ds: DataSet, index: number) => string { - if ((colorOption == null) || (colorOption.map == null)) { - return null; - } - const colorer = (ds: DataSet, i: number) => { - let value = ds.points[i].metadata[this.selectedColorOption.name]; - if (value == null) { - return POINT_COLOR_MISSING; - } - return colorOption.map(value); - }; - return colorer; - } - - private get3DLabelModeButton(): any { - return this.querySelector('#labels3DMode'); - } - - private get3DLabelMode(): boolean { - const label3DModeButton = 
this.get3DLabelModeButton(); - return (label3DModeButton as any).active; - } - - adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) { - this.notifySelectionChanged(selectedPointIndices); - this.notifyHoverOverPoint(hoverIndex); - this.setMouseMode(MouseMode.CAMERA_AND_CLICK_SELECT); - } - - private setMouseMode(mouseMode: MouseMode) { - let selectModeButton = this.querySelector('#selectMode'); - (selectModeButton as any).active = (mouseMode === MouseMode.AREA_SELECT); - this.projectorScatterPlotAdapter.scatterPlot.setMouseMode(mouseMode); - } - - private setCurrentDataSet(ds: DataSet) { - this.adjustSelectionAndHover([]); - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - if ((ds != null) && this.normalizeData) { - ds.normalize(); - } - this.dim = (ds == null) ? 0 : ds.dim[1]; - (this.querySelector('span.numDataPoints') as HTMLSpanElement).innerText = - (ds == null) ? '0' : '' + ds.dim[0]; - (this.querySelector('span.dim') as HTMLSpanElement).innerText = - (ds == null) ? '0' : '' + ds.dim[1]; - - this.dataSet = ds; - - this.projectionsPanel.dataSetUpdated( - this.dataSet, this.originalDataSet, this.dim); - - this.projectorScatterPlotAdapter.setDataSet(this.dataSet); - this.projectorScatterPlotAdapter.scatterPlot - .setCameraParametersForNextCameraCreation(null, true); - } - - private setupUIControls() { - // View controls - this.querySelector('#reset-zoom').addEventListener('click', () => { - this.projectorScatterPlotAdapter.scatterPlot.resetZoom(); - this.projectorScatterPlotAdapter.scatterPlot.startOrbitAnimation(); - }); - - let selectModeButton = this.querySelector('#selectMode'); - selectModeButton.addEventListener('click', (event) => { - this.setMouseMode( - (selectModeButton as any).active ? 
MouseMode.AREA_SELECT : - MouseMode.CAMERA_AND_CLICK_SELECT); - }); - let nightModeButton = this.querySelector('#nightDayMode'); - nightModeButton.addEventListener('click', () => { - this.projectorScatterPlotAdapter.scatterPlot.setDayNightMode( - (nightModeButton as any).active); - }); - - const labels3DModeButton = this.get3DLabelModeButton(); - labels3DModeButton.addEventListener('click', () => { - this.projectorScatterPlotAdapter.set3DLabelMode(this.get3DLabelMode()); - }); - - window.addEventListener('resize', () => { - const container = this.querySelector('#container') as HTMLDivElement; - const parentHeight = (container.parentNode as HTMLElement).clientHeight; - container.style.height = parentHeight + 'px'; - this.projectorScatterPlotAdapter.resize(); - }); - - { - this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter( - this.getScatterContainer(), this as ProjectorEventContext); - this.projectorScatterPlotAdapter.setLabelPointAccessor( - this.selectedLabelOption); - } - - this.projectorScatterPlotAdapter.scatterPlot.onCameraMove( - (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => - this.bookmarkPanel.clearStateSelection()); - - this.registerHoverListener( - (hoverIndex: number) => this.onHover(hoverIndex)); - - this.registerSelectionChangedListener( - (selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[]) => - this.onSelectionChanged( - selectedPointIndices, neighborsOfFirstPoint)); - } - - private onHover(hoverIndex: number) { - this.hoverPointIndex = hoverIndex; - let hoverText = null; - if (hoverIndex != null) { - const point = this.dataSet.points[hoverIndex]; - if (point.metadata[this.selectedLabelOption]) { - hoverText = point.metadata[this.selectedLabelOption].toString(); - } - } - if (this.selectedPointIndices.length === 0) { - this.statusBar.style.display = hoverText ? 
null : 'none'; - this.statusBar.innerText = hoverText; - } - } - - private getScatterContainer(): HTMLDivElement { - return this.querySelector('#scatter') as HTMLDivElement; - } - - private onSelectionChanged( - selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[]) { - this.selectedPointIndices = selectedPointIndices; - this.neighborsOfFirstPoint = neighborsOfFirstPoint; - let totalNumPoints = - this.selectedPointIndices.length + neighborsOfFirstPoint.length; - this.statusBar.innerText = `Selected ${totalNumPoints} points`; - this.statusBar.style.display = totalNumPoints > 0 ? null : 'none'; - } - - setProjection(projection: Projection) { - this.projection = projection; - if (projection != null) { - this.analyticsLogger.logProjectionChanged(projection.projectionType); - } - this.notifyProjectionChanged(projection); - } - - notifyProjectionPositionsUpdated() { - this.projectorScatterPlotAdapter.notifyProjectionPositionsUpdated(); - } - - /** - * Gets the current view of the embedding and saves it as a State object. - */ - getCurrentState(): State { - const state = new State(); - - // Save the individual datapoint projections. 
- state.projections = []; - for (let i = 0; i < this.dataSet.points.length; i++) { - const point = this.dataSet.points[i]; - const projections: {[key: string]: number} = {}; - const keys = Object.keys(point.projections); - for (let j = 0; j < keys.length; ++j) { - projections[keys[j]] = point.projections[keys[j]]; - } - state.projections.push(projections); - } - state.selectedProjection = this.projection.projectionType; - state.dataSetDimensions = this.dataSet.dim; - state.tSNEIteration = this.dataSet.tSNEIteration; - state.selectedPoints = this.selectedPointIndices; - state.filteredPoints = this.dataSetFilterIndices; - this.projectorScatterPlotAdapter.populateBookmarkFromUI(state); - state.selectedColorOptionName = this.dataPanel.selectedColorOptionName; - state.forceCategoricalColoring = this.dataPanel.forceCategoricalColoring; - state.selectedLabelOption = this.selectedLabelOption; - this.projectionsPanel.populateBookmarkFromUI(state); - return state; - } - - /** Loads a State object into the world. 
*/ - loadState(state: State) { - this.setProjection(null); - { - this.projectionsPanel.disablePolymerChangesTriggerReprojection(); - if (this.dataSetBeforeFilter != null) { - this.resetFilterDataset(); - } - if (state.filteredPoints != null) { - this.filterDataset(state.filteredPoints); - } - this.projectionsPanel.enablePolymerChangesTriggerReprojection(); - } - for (let i = 0; i < state.projections.length; i++) { - const point = this.dataSet.points[i]; - const projection = state.projections[i]; - const keys = Object.keys(projection); - for (let j = 0; j < keys.length; ++j) { - point.projections[keys[j]] = projection[keys[j]]; - } - } - this.dataSet.hasTSNERun = (state.selectedProjection === 'tsne'); - this.dataSet.tSNEIteration = state.tSNEIteration; - this.projectionsPanel.restoreUIFromBookmark(state); - this.inspectorPanel.restoreUIFromBookmark(state); - this.dataPanel.selectedColorOptionName = state.selectedColorOptionName; - this.dataPanel.setForceCategoricalColoring( - !!state.forceCategoricalColoring); - this.selectedLabelOption = state.selectedLabelOption; - this.projectorScatterPlotAdapter.restoreUIFromBookmark(state); - { - const dimensions = stateGetAccessorDimensions(state); - const components = - data.getProjectionComponents(state.selectedProjection, dimensions); - const projection = new Projection( - state.selectedProjection, components, dimensions.length, - this.dataSet); - this.setProjection(projection); - } - this.notifySelectionChanged(state.selectedPoints); - } -} - -document.registerElement(Projector.prototype.is, Projector); diff --git a/tensorflow/tensorboard/components/vz_sorting/BUILD b/tensorflow/tensorboard/components/vz_sorting/BUILD deleted file mode 100644 index e06b8ae19790490e73d3ceb552ea03d9f304e68d..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - 
-load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "vz_sorting", - srcs = [ - "sorting.ts", - "vz-sorting.html", - ], - path = "/vz-sorting", - visibility = ["//visibility:public"], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":vz_sorting"], - destdir = "vz-sorting", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_sorting/sorting.ts b/tensorflow/tensorboard/components/vz_sorting/sorting.ts deleted file mode 100644 index 061184d24bf30623e05834269b32acf745a56299..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/sorting.ts +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * Compares tag names asciinumerically broken into components. - * - *

This is the comparison function used for sorting most string values in - * TensorBoard. Unlike the standard asciibetical comparator, this function - * knows that 'a10b' > 'a2b'. Fixed point and engineering notation are - * supported. This function also splits the input by slash and underscore to - * perform array comparison. Therefore it knows that 'a/a' < 'a+/a' even - * though '+' < '/' in the ASCII table. - */ -export function compareTagNames(a, b: string): number { - let ai = 0; - let bi = 0; - while (true) { - if (ai === a.length) { - return bi === b.length ? 0 : -1; - } - if (bi === b.length) { - return 1; - } - if (isDigit(a[ai]) && isDigit(b[bi])) { - const ais = ai; - const bis = bi; - ai = consumeNumber(a, ai + 1); - bi = consumeNumber(b, bi + 1); - const an = parseFloat(a.slice(ais, ai)); - const bn = parseFloat(b.slice(bis, bi)); - if (an < bn) { - return -1; - } - if (an > bn) { - return 1; - } - continue; - } - if (isBreak(a[ai])) { - if (!isBreak(b[bi])) { - return -1; - } - } else if (isBreak(b[bi])) { - return 1; - } else if (a[ai] < b[bi]) { - return -1; - } else if (a[ai] > b[bi]) { - return 1; - } - ai++; - bi++; - } -} - -function consumeNumber(s: string, i: number): number { - enum State { NATURAL, REAL, EXPONENT_SIGN, EXPONENT } - let state = State.NATURAL; - for (; i < s.length; i++) { - if (state === State.NATURAL) { - if (s[i] === '.') { - state = State.REAL; - } else if (s[i] === 'e' || s[i] === 'E') { - state = State.EXPONENT_SIGN; - } else if (!isDigit(s[i])) { - break; - } - } else if (state === State.REAL) { - if (s[i] === 'e' || s[i] === 'E') { - state = State.EXPONENT_SIGN; - } else if (!isDigit(s[i])) { - break; - } - } else if (state === State.EXPONENT_SIGN) { - if (isDigit(s[i]) || s[i] === '+' || s[i] === '-') { - state = State.EXPONENT; - } else { - break; - } - } else if (state === State.EXPONENT) { - if (!isDigit(s[i])) { - break; - } - } - } - return i; -} - -function isDigit(c: string): boolean { - return '0' <= c && c <= 
'9'; -} - -function isBreak(c: string): boolean { - // TODO(jart): Remove underscore when people stop using it like a slash. - return c === '/' || c === '_' || isDigit(c); -} diff --git a/tensorflow/tensorboard/components/vz_sorting/test/BUILD b/tensorflow/tensorboard/components/vz_sorting/test/BUILD deleted file mode 100644 index 929e80d37282387823ea4a93874a112710269cc1..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:vulcanize.bzl", "tensorboard_html_binary") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "sortingTests.ts", - "tests.html", - ], - path = "/vz-sorting/test", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/vz_sorting", - ], -) - -tensorboard_html_binary( - name = "devserver", - compilation_level = "WHITESPACE_ONLY", - input_path = "/vz-sorting/test/tests.html", - output_path = "/vz-sorting/test/tests.html", - deps = [":test"], -) - -filegroup( - name = "all_files", - testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts b/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts deleted file mode 100644 index 510685cb4b5e42ca19e56acef6b1f87347811c99..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {compareTagNames} from '../sorting'; - -describe('compareTagNames', () => { - - const assert = chai.assert; - const sortTagNames = (a) => a.sort(compareTagNames); - - it('is asciibetical', () => { - assert.deepEqual(sortTagNames(['a', 'b']), ['a', 'b']); - assert.deepEqual(sortTagNames(['a', 'B']), ['B', 'a']); - }); - - it('sorts integer portions', () => { - assert.deepEqual(['03', '1'].sort(), ['03', '1']); - assert.deepEqual(sortTagNames(['03', '1']), ['1', '03']); - assert.deepEqual(sortTagNames(['a03', 'a1']), ['a1', 'a03']); - assert.deepEqual(sortTagNames(['a03', 'b1']), ['a03', 'b1']); - assert.deepEqual(sortTagNames(['x0a03', 'x0a1']), ['x0a1', 'x0a03']); - assert.deepEqual(sortTagNames(['a/b/03', 'a/b/1']), ['a/b/1', 'a/b/03']); - }); - - it('sorts fixed point numbers', () => { - assert.deepEqual(sortTagNames(['a0.1', 'a0.01']), ['a0.01', 'a0.1']); - }); - - it('sorts engineering notation', () => { - assert.deepEqual(sortTagNames(['a1e9', 'a9e8']), ['a9e8', 'a1e9']); - assert.deepEqual(sortTagNames(['a1e+9', 'a9e+8']), ['a9e+8', 'a1e+9']); - assert.deepEqual(sortTagNames(['a1e+5', 'a9e-6']), ['a9e-6', 'a1e+5']); - assert.deepEqual(sortTagNames(['a1.0e9', 'a9.0e8']), ['a9.0e8', 'a1.0e9']); - assert.deepEqual( - sortTagNames(['a1.0e+9', 'a9.0e+8']), ['a9.0e+8', 'a1.0e+9']); - }); - - it('is componentized by slash', () => { - assert.deepEqual(['a+/a', 'a/a', 'ab/a'].sort(), ['a+/a', 'a/a', 'ab/a']); - assert.deepEqual( - sortTagNames(['a+/a', 'a/a', 'ab/a']), ['a/a', 'a+/a', 
'ab/a']); - }); - - it('is componentized by underscore', () => { - assert.deepEqual( - sortTagNames(['a+_a', 'a_a', 'ab_a']), ['a_a', 'a+_a', 'ab_a']); - assert.deepEqual( - sortTagNames(['a+/a', 'a_a', 'ab_a']), ['a_a', 'a+/a', 'ab_a']); - }); - - it('is componentized by number boundaries', () => { - assert.deepEqual( - sortTagNames(['a+0a', 'a0a', 'ab0a']), ['a0a', 'a+0a', 'ab0a']); - }); - - it('empty comes first', () => { - assert.deepEqual(sortTagNames(['a', '//', '/', '']), ['', '/', '//', 'a']); - }); - - it('decimal parsed correctly', () => { - assert.deepEqual(sortTagNames(['0.2', '0.03']), ['0.03', '0.2']); - assert.deepEqual(sortTagNames(['0..2', '0..03']), ['0..2', '0..03']); - assert.deepEqual(sortTagNames(['.2', '.03']), ['.2', '.03']); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_sorting/test/tests.html b/tensorflow/tensorboard/components/vz_sorting/test/tests.html deleted file mode 100644 index f92c608cdb125ec7e6d6b538d089f2779732ce6a..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/tests.html +++ /dev/null @@ -1,23 +0,0 @@ - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html b/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html deleted file mode 100644 index 5ff6f311589d2ef1c65dbfb052d255390c36991f..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/defs/BUILD b/tensorflow/tensorboard/defs/BUILD deleted file mode 100644 index 92a2af34048deaf6da07a7b14aa42e4cd8202958..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -filegroup( - name = "ts_web_library_default_typings", - srcs = [ - # Ordering probably matters. 
- "@com_microsoft_typescript//:lib.es6.d.ts", - "@io_angular_clutz//:src/resources/closure.lib.d.ts", - "clutz.d.ts", - ], - visibility = ["//visibility:public"], -) diff --git a/tensorflow/tensorboard/defs/clutz.d.ts b/tensorflow/tensorboard/defs/clutz.d.ts deleted file mode 100644 index 47cf307d2619a4a84f631dceb03b393cd04aa0d6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/clutz.d.ts +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// tslint:disable -declare namespace ಠ_ಠ.clutz { - interface IteratorIterable extends Iterator, Iterable {} - interface IIterableResult extends IteratorResult {} -} diff --git a/tensorflow/tensorboard/defs/hacks.bzl b/tensorflow/tensorboard/defs/hacks.bzl deleted file mode 100644 index f1d4be790612ac912dc1b1a2298f8bc8dd99dee6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/hacks.bzl +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO(jart): Merge this file into defs.bzl once that file is sync unified. - -def tensorboard_typescript_bundle( - name, - out, - namespace_srcs, - namespace_symbol_aliases={}, - namespace_symbol_aliases_public={}, - **kwargs): - """Rolls TypeScript ES6 modules into one vanilla source file without imports. - - This is a genrule wrapper that concatenates TypeScripts sources inside - namespace blocks while removing ^import lines. Because the sources themselves - are not parsed, the structure of the modules must be passed to this macro as - a Skylark data structure. - - Args: - name: Name of this build rule target. - out: Path of outputted TypeScript source file. - namespace_srcs: Multimap of namespace strings to build file targets. The - ordering of the dictionary and nested lists does not matter when - generating a typings file, but *does* matter when generating a source - file. - namespace_symbol_aliases: Map of namespace strings where each value is a - map of symbol names to fully qualified symbol names. - namespace_symbol_aliases_public: Same as namespace_symbol_aliases but the - symbol will be visible to other namespaces. 
- """ - cmd = ["(", "echo // GENERATED BY TENSORBOARD_TYPESCRIPT_BUNDLE"] - inputs = set() - for namespace, srcs in namespace_srcs.items(): - cmd.append("echo") - if out[-5:] == ".d.ts": - cmd.append("echo 'declare namespace %s {'" % namespace) - elif out[-3:] == ".ts": - cmd.append("echo 'module %s {'" % namespace) - else: - fail("'out' must end with .ts or .d.ts: " + out) - for symbol, canon in namespace_symbol_aliases.get(namespace, {}).items(): - cmd.append("echo 'import %s = %s;'" % (symbol, canon)) - for symbol, canon in namespace_symbol_aliases_public.get(namespace, - {}).items(): - cmd.append("echo 'export import %s = %s;'" % (symbol, canon)) - inputs += srcs - for src in srcs: - cmd.append("for f in $(locations %s); do" % src) - cmd.append(" echo") - cmd.append(" echo /////////////////////////////////////////////////////") - cmd.append(" echo // " + namespace) - cmd.append(" echo // $$f") - cmd.append(" echo /////////////////////////////////////////////////////") - cmd.append(" echo") - cmd.append(" sed 's!^import !// import !' $$f \\") - cmd.append(" | sed 's!^export declare !export !' \\") - cmd.append(" | sed '/^export .* from /d' \\") - cmd.append(" | sed '/^export {.*};$$/d'") - cmd.append("done") - cmd.append("echo '}'") - cmd.append(") >$@") - native.genrule( - name = name, - srcs = list(inputs), - outs = [out], - cmd = "\n".join(cmd), - **kwargs - ) diff --git a/tensorflow/tensorboard/defs/protos.bzl b/tensorflow/tensorboard/defs/protos.bzl deleted file mode 100644 index 6d1982e098d9c549a3f6387035c6877d0b798ab7..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/protos.bzl +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@protobuf//:protobuf.bzl", "py_proto_library") - -def tb_proto_library(name, srcs = [], visibility = []): - py_proto_library( - name = name + "_py", - srcs = srcs, - srcs_version = "PY2AND3", - deps = ["@protobuf//:protobuf_python"], - protoc = "@protobuf//:protoc", - visibility = visibility, - default_runtime = "@protobuf//:protobuf_python", - testonly = 0, - ) \ No newline at end of file diff --git a/tensorflow/tensorboard/defs/vulcanize.bzl b/tensorflow/tensorboard/defs/vulcanize.bzl deleted file mode 100644 index 6ff49a35ed73f0a8a5fb7ce5b3544e0807e1c0bc..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/vulcanize.bzl +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/tensorboard/defs:defs.bzl", "legacy_js") -load("@io_bazel_rules_closure//closure/private:defs.bzl", "collect_js", "unfurl", "long_path") -load("//tensorflow/tensorboard/defs:web.bzl", "web_aspect") - -def _tensorboard_html_binary(ctx): - deps = unfurl(ctx.attr.deps, provider="webfiles") - manifests = set(order="topological") - files = set() - webpaths = set() - for dep in deps: - manifests += dep.webfiles.manifests - webpaths += dep.webfiles.webpaths - files += dep.data_runfiles.files - webpaths += [ctx.attr.output_path] - closure_js_library=collect_js( - ctx, unfurl(ctx.attr.deps, provider="closure_js_library")) - - # vulcanize - jslibs = depset(ctx.files._jslibs) + closure_js_library.srcs - ctx.action( - inputs=list(manifests | files | jslibs), - outputs=[ctx.outputs.html], - executable=ctx.executable._Vulcanize, - arguments=([ctx.attr.compilation_level, - "true" if ctx.attr.testonly else "false", - ctx.attr.input_path, - ctx.attr.output_path, - ctx.outputs.html.path] + - [f.path for f in jslibs] + - [f.path for f in manifests]), - progress_message="Vulcanizing %s" % ctx.attr.input_path) - - # webfiles manifest - manifest_srcs = [struct(path=ctx.outputs.html.path, - longpath=long_path(ctx, ctx.outputs.html), - webpath=ctx.attr.output_path)] - manifest = ctx.new_file(ctx.configuration.bin_dir, - "%s.pbtxt" % ctx.label.name) - ctx.file_action( - output=manifest, - content=struct( - label=str(ctx.label), - src=manifest_srcs).to_proto()) - manifests += [manifest] - - # webfiles server - params = struct( - label=str(ctx.label), - bind="[::]:6006", - manifest=[long_path(ctx, man) for man in manifests], - external_asset=[struct(webpath=k, path=v) - for k, v in ctx.attr.external_assets.items()]) - params_file = ctx.new_file(ctx.configuration.bin_dir, - "%s_server_params.pbtxt" % ctx.label.name) - ctx.file_action(output=params_file, content=params.to_proto()) - ctx.file_action( - executable=True, - output=ctx.outputs.executable, - 
content="#!/bin/sh\nexec %s %s" % ( - ctx.executable._WebfilesServer.short_path, - long_path(ctx, params_file))) - - transitive_runfiles = depset() - transitive_runfiles += ctx.attr._WebfilesServer.data_runfiles.files - for dep in deps: - transitive_runfiles += dep.data_runfiles.files - return struct( - files=depset([ctx.outputs.html]), - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=ctx.outputs.html), - runfiles=ctx.runfiles( - files=ctx.files.data + [manifest, - params_file, - ctx.outputs.html, - ctx.outputs.executable], - transitive_files=transitive_runfiles)) - -tensorboard_html_binary = rule( - implementation=_tensorboard_html_binary, - executable=True, - attrs={ - "compilation_level": attr.string(default="ADVANCED"), - "input_path": attr.string(mandatory=True), - "output_path": attr.string(mandatory=True), - "data": attr.label_list(cfg="data", allow_files=True), - "deps": attr.label_list( - aspects=[ - web_aspect, - legacy_js, - ], - mandatory=True), - "external_assets": attr.string_dict(default={"/_/runfiles": "."}), - "_jslibs": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:jslibs"), - allow_files=True), - "_Vulcanize": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:Vulcanize"), - executable=True, - cfg="host"), - "_WebfilesServer": attr.label( - default=Label( - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles/server:WebfilesServer"), - executable=True, - cfg="host"), - }, - outputs={ - "html": "%{name}.html", - }) diff --git a/tensorflow/tensorboard/defs/web.bzl b/tensorflow/tensorboard/defs/web.bzl deleted file mode 100644 index 103942b0a25d2706b1af445383689dca02407d91..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/web.bzl +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Same as web_library but supports TypeScript.""" - -load("//tensorflow/tensorboard/defs:defs.bzl", "legacy_js") - -load("//third_party:clutz.bzl", - "CLUTZ_ATTRIBUTES", - "CLUTZ_OUTPUTS", - "clutz_aspect", - "extract_dts_from_closure_libraries") - -load("@io_bazel_rules_closure//closure/private:defs.bzl", - "CLOSURE_LIBRARY_BASE_ATTR", - "CLOSURE_LIBRARY_DEPS_ATTR", - "collect_js", - "collect_runfiles", - "convert_path_to_es6_module_name", - "create_argfile", - "difference", - "long_path", - "unfurl") - -_ASPECT_SLURP_FILE_TYPE = FileType([ - ".html", ".js", ".css", ".gss", ".png", ".jpg", ".gif", ".ico", ".svg"]) - -_CLOSURE_WORKER = attr.label( - default=Label("@io_bazel_rules_closure//java/io/bazel/rules/closure:ClosureWorker"), - executable=True, - cfg="host") - -def _ts_web_library(ctx): - if not ctx.attr.srcs: - if ctx.attr.deps: - fail("deps can not be set when srcs is not") - if not ctx.attr.exports: - fail("exports must be set if srcs is not") - if ctx.attr.path: - if not ctx.attr.path.startswith("/"): - fail("webpath must start with /") - if ctx.attr.path != "/" and ctx.attr.path.endswith("/"): - fail("webpath must not end with / unless it is /") - if "//" in ctx.attr.path: - fail("webpath must not have //") - elif ctx.attr.srcs: - fail("path must be set when srcs is set") - if "*" in ctx.attr.suppress and len(ctx.attr.suppress) != 1: - fail("when \"*\" is suppressed no other items should be present") - - # 
process what came before - deps = unfurl(ctx.attr.deps, provider="webfiles") - webpaths = depset() - ts_typings = depset(ctx.files._default_typings) - ts_typings_paths = depset( - [long_path(ctx, f) for f in ctx.files._default_typings]) - ts_typings_execroots = depset() - aspect_runfiles = depset() - for dep in deps: - webpaths += dep.webfiles.webpaths - if hasattr(dep.webfiles, "ts_typings"): - ts_typings += dep.webfiles.ts_typings - if hasattr(dep.webfiles, "ts_typings_paths"): - ts_typings_paths += dep.webfiles.ts_typings_paths - if hasattr(dep.webfiles, "ts_typings_execroots"): - ts_typings_execroots += dep.webfiles.ts_typings_execroots - if hasattr(dep.webfiles, "aspect_runfiles"): - aspect_runfiles += dep.webfiles.aspect_runfiles - - # process what comes now - manifest_srcs = [] - new_webpaths = [] - ts_inputs = depset() - ts_outputs = [] - ts_files = list(ts_typings_paths) - new_typings = [] - new_typings_paths = [] - new_typings_execroot = struct(inputs=[]) - execroot = struct( - inputs=[(long_path(ctx, f), f.path) for f in ctx.files._default_typings], - outputs=[], - program=[ctx.executable._tsc.path, "-p"]) - web_srcs = [] - path = ctx.attr.path - strip = _get_strip(ctx) - for src in ctx.files.srcs: - suffix = _get_path_relative_to_package(src) - if strip: - if not suffix.startswith(strip): - fail("Relative src path not start with '%s': %s" % (strip, suffix)) - suffix = suffix[len(strip):] - webpath = "%s/%s" % ("" if path == "/" else path, suffix) - _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs) - if suffix.endswith(".d.ts"): - web_srcs.append(src) - entry = (webpath[1:], src.path) - new_typings.append(src) - new_typings_paths.append(entry[0]) - new_typings_execroot.inputs.append(entry) - ts_inputs += [src] - ts_files.append(entry[0]) - execroot.inputs.append(entry) - elif suffix.endswith(".ts"): - noext = suffix[:-3] - js = ctx.new_file(ctx.bin_dir, "%s.js" % noext) - dts = ctx.new_file(ctx.bin_dir, "%s.d.ts" % noext) - 
webpath_js = webpath[:-3] + ".js" - webpath_dts = webpath[:-3] + ".d.ts" - _add_webpath(ctx, js, webpath_js, webpaths, new_webpaths, manifest_srcs) - _add_webpath(ctx, dts, webpath_dts, webpaths, new_webpaths, manifest_srcs) - ts_inputs += [src] - ts_outputs.append(js) - ts_outputs.append(dts) - web_srcs.append(dts) - web_srcs.append(js) - ts_files.append(webpath[1:]) - execroot.inputs.append((webpath[1:], src.path)) - execroot.outputs.append((webpath_js[1:], js.path)) - execroot.outputs.append((webpath_dts[1:], dts.path)) - new_typings.append(dts) - new_typings_paths.append(webpath_dts[1:]) - new_typings_execroot.inputs.append((webpath_dts[1:], dts.path)) - else: - web_srcs.append(src) - - # get typings for closure code - clutz_dts = extract_dts_from_closure_libraries(ctx) - if clutz_dts: - entry = (long_path(ctx, clutz_dts), clutz_dts.path) - ts_inputs += [clutz_dts] - ts_files.append(entry[0]) - execroot.inputs.append(entry) - - # compile typescript - workspace = "" - if ctx.label.workspace_root: - workspace = "/" + ctx.label.workspace_root - if execroot.outputs: - ts_config = _new_file(ctx, "-tsc.json") - execroot.inputs.append(("tsconfig.json", ts_config.path)) - ctx.file_action( - output=ts_config, - content=struct( - compilerOptions=struct( - baseUrl=".", - declaration=True, - inlineSourceMap=True, - inlineSources=True, - module="es6", - moduleResolution="node", - noResolve=True, - target="es5", - ), - files=ts_files, - ).to_json()) - er_config = _new_file(ctx, "-tsc-execroot.json") - ctx.file_action(output=er_config, content=execroot.to_json()) - ts_inputs += collect_runfiles([ctx.attr._tsc]) - ts_inputs += ctx.files._tsc - ts_inputs += ts_typings - ts_inputs += ts_typings_execroots - ts_inputs += [ts_config, er_config] - ctx.action( - inputs=list(ts_inputs), - outputs=ts_outputs, - executable=ctx.executable._execrooter, - arguments=[er_config.path] + [f.path for f in ts_typings_execroots], - progress_message="Compiling %d TypeScript files %s" % ( - 
len(ts_files), ctx.label)) - - # perform strict dependency checking - manifest = _make_manifest(ctx, manifest_srcs) - webpaths += new_webpaths - dummy, manifests = _run_webfiles_validator(ctx, web_srcs, deps, manifest) - web_srcs.append(dummy) - - # define development web server that only applies to this transitive closure - params = struct( - label=str(ctx.label), - bind="[::]:6006", - manifest=[long_path(ctx, man) for man in manifests], - external_asset=[struct(webpath=k, path=v) - for k, v in ctx.attr.external_assets.items()]) - params_file = _new_file(ctx, "-params.pbtxt") - ctx.file_action(output=params_file, content=params.to_proto()) - ctx.file_action( - executable=True, - output=ctx.outputs.executable, - content="#!/bin/sh\nexec %s %s" % ( - ctx.executable._WebfilesServer.short_path, - long_path(ctx, params_file))) - - if new_typings: - er_config = _new_file(ctx, "-typings-execroot.json") - ctx.file_action(output=er_config, content=new_typings_execroot.to_json()) - ts_typings += new_typings - ts_typings_paths += new_typings_paths - ts_typings_execroots += [er_config] - else: - ts_typings = depset() - ts_typings_paths = depset() - ts_typings_execroots = depset() - - # export data to parent rules - return struct( - files=depset(web_srcs + [dummy]), - exports=unfurl(ctx.attr.exports), - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=dummy, - ts_typings=ts_typings, - ts_typings_paths=ts_typings_paths, - ts_typings_execroots=ts_typings_execroots), - closure_js_library=collect_js( - ctx, unfurl(ctx.attr.deps, provider="closure_js_library")), - runfiles=ctx.runfiles( - files=ctx.files.srcs + ctx.files.data + ts_outputs + [ - manifest, - params_file, - ctx.outputs.executable, - dummy], - transitive_files=(collect_runfiles([ctx.attr._WebfilesServer]) | - collect_runfiles(deps) | - collect_runfiles(ctx.attr.data) | - aspect_runfiles))) - -def _web_aspect_impl(target, ctx): - if hasattr(target, "webfiles"): - return struct() 
- srcs = [] - deps = [] - if hasattr(ctx.rule.files, "srcs"): - srcs.extend(_ASPECT_SLURP_FILE_TYPE.filter(ctx.rule.files.srcs)) - for attr in ("deps", "sticky_deps", "module_deps"): - value = getattr(ctx.rule.attr, attr, None) - if value: - deps.extend(value) - deps = unfurl(deps, provider="webfiles") - webpaths = depset() - aspect_runfiles = depset(srcs) - for dep in deps: - webpaths += dep.webfiles.webpaths - if hasattr(dep.webfiles, "aspect_runfiles"): - aspect_runfiles += dep.webfiles.aspect_runfiles - manifest_srcs = [] - new_webpaths = [] - for src in srcs: - webpath = "/" + long_path(ctx, src) - _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs) - webpaths += new_webpaths - manifest = _make_manifest(ctx, manifest_srcs) - dummy, manifests = _run_webfiles_validator(ctx, srcs, deps, manifest) - aspect_runfiles += [dummy, manifest] - return struct( - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=dummy, - aspect_runfiles=aspect_runfiles)) - -def _make_manifest(ctx, src_list): - manifest = _new_file(ctx, "-webfiles.pbtxt") - ctx.file_action( - output=manifest, - content=struct( - label=str(ctx.label), - src=src_list).to_proto()) - return manifest - -def _run_webfiles_validator(ctx, srcs, deps, manifest): - dummy = _new_file(ctx, "-webfiles.ignoreme") - manifests = depset(order="topological") - for dep in deps: - manifests += dep.webfiles.manifests - if srcs: - args = ["WebfilesValidator", - "--dummy", dummy.path, - "--target", manifest.path] - if hasattr(ctx, "attr") and hasattr(ctx.attr, "suppress"): - for category in ctx.attr.suppress: - args.append("--suppress") - args.append(category) - inputs = [manifest] - inputs.extend(srcs) - direct_manifests = depset() - for dep in deps: - inputs.append(dep.webfiles.dummy) - for f in dep.files: - inputs.append(f) - direct_manifests += [dep.webfiles.manifest] - inputs.append(dep.webfiles.manifest) - args.append("--direct_dep") - 
args.append(dep.webfiles.manifest.path) - for man in difference(manifests, direct_manifests): - inputs.append(man) - args.append("--transitive_dep") - args.append(man.path) - argfile = _new_file(ctx, "-webfiles-checker-args.txt") - ctx.file_action(output=argfile, content="\n".join(args)) - inputs.append(argfile) - ctx.action( - inputs=inputs, - outputs=[dummy], - executable=(getattr(ctx.executable, "_ClosureWorker", None) or - getattr(ctx.executable, "_ClosureWorkerAspect", None)), - arguments=["@@" + argfile.path], - mnemonic="Closure", - execution_requirements={"supports-workers": "1"}, - progress_message="Checking webfiles %s" % ctx.label) - else: - ctx.file_action(output=dummy, content="BOO!") - manifests += [manifest] - return dummy, manifests - -def _new_file(ctx, suffix): - return ctx.new_file(ctx.bin_dir, "%s%s" % (ctx.label.name, suffix)) - -def _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs): - if webpath in new_webpaths: - _fail(ctx, "multiple srcs within %s define the webpath %s " % ( - ctx.label, webpath)) - if webpath in webpaths: - _fail(ctx, "webpath %s was defined by %s when already defined by deps" % ( - webpath, ctx.label)) - new_webpaths.append(webpath) - manifest_srcs.append(struct( - path=src.path, - longpath=long_path(ctx, src), - webpath=webpath)) - -def _fail(ctx, message): - if ctx.attr.suppress == ["*"]: - print(message) - else: - fail(message) - -def _get_path_relative_to_package(artifact): - """Returns file path relative to the package that declared it.""" - path = artifact.path - for prefix in (artifact.root.path, - artifact.owner.workspace_root if artifact.owner else '', - artifact.owner.package if artifact.owner else ''): - if prefix: - prefix = prefix + "/" - if not path.startswith(prefix): - fail("Path %s doesn't start with %s" % (path, prefix)) - path = path[len(prefix):] - return path - -def _get_strip(ctx): - strip = ctx.attr.strip_prefix - if strip: - if strip.startswith("/"): - _fail(ctx, "strip_prefix 
should not end with /") - strip = strip[1:] - if strip.endswith("/"): - _fail(ctx, "strip_prefix should not end with /") - else: - strip += "/" - return strip - -web_aspect = aspect( - implementation=_web_aspect_impl, - attr_aspects=["deps", "sticky_deps", "module_deps"], - attrs={"_ClosureWorkerAspect": _CLOSURE_WORKER}) - -ts_web_library = rule( - implementation=_ts_web_library, - executable=True, - attrs=CLUTZ_ATTRIBUTES + { - "path": attr.string(), - "srcs": attr.label_list(allow_files=True), - "deps": attr.label_list( - aspects=[ - web_aspect, - clutz_aspect, - legacy_js, - ]), - "exports": attr.label_list(), - "data": attr.label_list(cfg="data", allow_files=True), - "suppress": attr.string_list(), - "strip_prefix": attr.string(), - "external_assets": attr.string_dict(default={"/_/runfiles": "."}), - "clutz_entry_points": attr.string_list(), - "_execrooter": attr.label( - default=Label("//tensorflow/tensorboard/scripts:execrooter"), - executable=True, - cfg="host"), - "_tsc": attr.label( - default=Label("@com_microsoft_typescript//:tsc"), - allow_files=True, - executable=True, - cfg="host"), - "_default_typings": attr.label( - default=Label("//tensorflow/tensorboard:ts_web_library_default_typings"), - allow_files=True), - "_WebfilesServer": attr.label( - default=Label("@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles/server:WebfilesServer"), - executable=True, - cfg="host"), - "_ClosureWorker": _CLOSURE_WORKER, - "_closure_library_base": CLOSURE_LIBRARY_BASE_ATTR, - "_closure_library_deps": CLOSURE_LIBRARY_DEPS_ATTR, - }, - outputs=CLUTZ_OUTPUTS) diff --git a/tensorflow/tensorboard/defs/zipper.bzl b/tensorflow/tensorboard/defs/zipper.bzl deleted file mode 100644 index e98309ec9a5d5185ac48e235ceb10d0d3f0e153d..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/defs/zipper.bzl +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@io_bazel_rules_closure//closure/private:defs.bzl", "unfurl", "long_path") - -def _tensorboard_zip_file(ctx): - deps = unfurl(ctx.attr.deps, provider="webfiles") - manifests = set(order="link") - files = set() - webpaths = set() - for dep in deps: - manifests += dep.webfiles.manifests - webpaths += dep.webfiles.webpaths - files += dep.data_runfiles.files - ctx.action( - inputs=list(manifests + files), - outputs=[ctx.outputs.zip], - executable=ctx.executable._Zipper, - arguments=([ctx.outputs.zip.path] + - [m.path for m in manifests]), - progress_message="Zipping %d files" % len(webpaths)) - transitive_runfiles = set() - for dep in deps: - transitive_runfiles += dep.data_runfiles.files - return struct( - files=set([ctx.outputs.zip]), - runfiles=ctx.runfiles( - files=ctx.files.data + [ctx.outputs.zip], - transitive_files=transitive_runfiles)) - -tensorboard_zip_file = rule( - implementation=_tensorboard_zip_file, - attrs={ - "data": attr.label_list(cfg="data", allow_files=True), - "deps": attr.label_list(providers=["webfiles"], mandatory=True), - "_Zipper": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:Zipper"), - executable=True, - cfg="host"), - }, - outputs={ - "zip": "%{name}.zip", - }) diff --git a/tensorflow/tensorboard/demo/BUILD b/tensorflow/tensorboard/demo/BUILD deleted file mode 100644 index 
b253572ec556314356dee4911eeb755e6da18950..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/BUILD +++ /dev/null @@ -1,20 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") - -licenses(["notice"]) # Apache 2.0 - -# THIS PACKAGE HAS MOVED -# See tensorflow/tensorboard/components/tf_tensorboard:demo - -web_library( - name = "demo_data", - srcs = glob(["data/**"]), - path = "/", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json b/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json deleted file mode 100644 index 7dfe32c7112c61bcacf896de2d906bc06a9c952f..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"query": "index=0&tag=au1%2Faudio%2F0&run=run1", "step": 0, "wall_time": 1461795049.203407, "content_type": "audio/wav"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json b/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json deleted file mode 100644 index 13f9c2de4265d08a3b3635360d380c018f7aed7b..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"query": "index=0&tag=au2%2Faudio%2F0&run=run2", "step": 0, "wall_time": 1461795049.212815, "content_type": "audio/wav"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json deleted file mode 100644 index 6ae6fbf880e61bb8f7dfe3ed0a32dcba3e5d40cd..0000000000000000000000000000000000000000 --- 
a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -2.3150592308536755], [668, -2.0967547155036605], [1587, -1.4326244423655616], [3085, -0.8871306575801902], [5000, -0.09312398815580714], [6915, 0.2584093405812282], [8413, 0.8895470642005087], [9332, 1.3198979614453679], [10000, 1.6793308878855118]]], [100.0, 10, [[0, -1.3417572789138936], [668, -1.183563374619141], [1587, -0.48920418783271574], [3085, 0.29326906896076954], [5000, 0.56953784145381], [6915, 0.8684655583499333], [8413, 1.4133127368907181], [9332, 1.906140650457873], [10000, 2.135771998171255]]], [200.0, 20, [[0, -1.5066917525035333], [668, -1.3910909571770793], [1587, -0.902737218885874], [3085, -0.3807791904765027], [5000, 0.38900200905253046], [6915, 0.8209734209339482], [8413, 1.302385856695965], [9332, 1.9324626053521639], [10000, 2.957505317875451]]], [300.0, 30, [[0, -0.5430457051469562], [668, -0.4626161834245273], [1587, 0.21573949543027715], [3085, 0.37353741100174215], [5000, 0.6891407881591103], [6915, 1.0927156232630852], [8413, 1.2745337159550916], [9332, 1.4321116832891605], [10000, 2.1913774993059034]]], [400.0, 40, [[0, -0.3584790755077172], [668, -0.33301611509753215], [1587, -0.1089466072951948], [3085, 0.5792199847585249], [5000, 1.220854943811942], [6915, 1.759829438421432], [8413, 2.3072559906741614], [9332, 2.753036118353921], [10000, 3.0267252195784047]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json deleted file mode 100644 index 3ad520c5687cdec798b401d3740814de75d39bc8..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -3.6801669545044846], [668, -3.192188140974744], [1587, -2.3414678549368806], [3085, -0.9632173471995873], [5000, 
-0.3214892636797772], [6915, 0.11870794142185205], [8413, 0.8895470642005087], [9332, 1.183563374619141], [10000, 2.665663810418372]]], [100.0, 10, [[0, -3.564793583751807], [668, -3.376844436865802], [1587, -1.0366615731293798], [3085, -0.27318696312672563], [5000, 0.9718642422053263], [6915, 2.5765662807928194], [8413, 3.1415385101545126], [9332, 4.085981768607621], [10000, 4.623079406808927]]], [200.0, 20, [[0, -2.235172510433281], [668, -2.004569042815611], [1587, -1.2015432383370985], [3085, 0.11835464933202625], [5000, 0.56953784145381], [6915, 1.202844810963146], [8413, 2.689066032283515], [9332, 2.8494015726499944], [10000, 3.481377676013788]]], [300.0, 30, [[0, -3.360113978269659], [668, -2.8293185004961043], [1587, -1.5992540502266783], [3085, 0.14393860259807117], [5000, 1.47723448201245], [6915, 1.9510057389110733], [8413, 2.833176104473626], [9332, 4.142405216576347], [10000, 4.706937777668589]]], [400.0, 40, [[0, -2.599286228987632], [668, -2.240365897443259], [1587, -1.5992540502266783], [3085, -0.9101893288861387], [5000, 0.7580548669750213], [6915, 1.6009864433919474], [8413, 2.3504002974280036], [9332, 2.7907805263353733], [10000, 3.5098048900144323]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json deleted file mode 100644 index a3802ba2365adadb2453809fdf77d07ee5ef9b1f..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -1.9291158122759586], [668, -1.5970765333488954], [1587, -1.0923120348519078], [3085, -0.6688082872192093], [5000, 0.09312398815580714], [6915, 0.44532789251701854], [8413, 0.8238009655877649], [9332, 1.0357232383581656], [10000, 1.2741043689144438]]], [100.0, 10, [[0, -0.7780725642449806], [668, -0.7138496178727424], [1587, -0.5448932415735014], [3085, 
-0.24370397454796228], [5000, 0.42790220995778355], [6915, 0.6191730643365096], [8413, 0.752059342118037], [9332, 1.0451472255274825], [10000, 2.5559479569222825]]], [200.0, 20, [[0, -1.3876904425996377], [668, -1.1464188862638496], [1587, -0.4049955219067526], [3085, 0.04721394862139682], [5000, 0.56953784145381], [6915, 1.3221859041483333], [8413, 1.6188495656305735], [9332, 1.7613953069723651], [10000, 2.3257482385477384]]], [300.0, 30, [[0, -1.600772629982185], [668, -1.1548516185367033], [1587, -0.260387173785447], [3085, 0.17416570914366614], [5000, 0.47069243095356195], [6915, 1.1559276581637614], [8413, 2.0474031182051404], [9332, 2.18821711651116], [10000, 2.2393193406467518]]], [400.0, 40, [[0, -0.8286852465281818], [668, -0.7815041529866706], [1587, -0.3334896444053469], [3085, 0.21085213041026643], [5000, 0.5177616740489182], [6915, 1.077122434649409], [8413, 1.5898009703967424], [9332, 1.8859097291499742], [10000, 2.0954239138728523]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt b/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt deleted file mode 100644 index 2a6af3284086b4d797ebf3598bffe286d74baddf..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt +++ /dev/null @@ -1,9 +0,0 @@ -node { - name: "a" - op: "matmul" -} -node { - name: "b" - op: "matmul" - input: "a:0" -} diff --git a/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt b/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt deleted file mode 100644 index a5a4d65d5c61a7cf1c208b48f841a38a03847d60..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt +++ /dev/null @@ -1,15 +0,0 @@ -node { - name: "a" - op: "matmul" -} -node { - name: "b" - op: "matmul" - input: "a:0" -} -node { - name: "c" - op: "matmul" - input: "a:0" - input: "b:0" -} diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json 
b/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json deleted file mode 100644 index a5600a356e8277e58be3b2891c3e328d058b5d08..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.3584790755077172, 3.0267252195784047, 20.0, 24.012225532303315, 48.29045006426564, [-0.35363819004775493, -0.29226296698161564, -0.19961953895336082, 0.3214892636797772, 0.5177616740489182, 0.56953784145381, 0.6264916255991911, 0.7580548669750213, 0.8338603536725235, 1.220854943811942, 1.3429404381931362, 1.47723448201245, 1.624957930213695, 1.7874537232350647, 1.9661990955585713, 2.379100905625872, 2.6170109961884593, 3.1665833053880363], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json b/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json deleted file mode 100644 index 407c375d2fc710e70408a3238df3a6165e964e84..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-2.599286228987632, 3.5098048900144323, 20.0, 10.792285491200078, 66.66796979177158, [-2.379100905625872, -1.9661990955585713, -1.624957930213695, -1.47723448201245, -1.109868130738129, -1.0089710279437536, -0.42790220995778355, -0.2195814928486969, 0.47069243095356195, 0.7580548669750213, 0.917246389039776, 1.3429404381931362, 1.624957930213695, 1.7874537232350647, 2.1628190051144287, 2.6170109961884593, 2.8787120958073054, 3.8315657995195243], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 1.0, 1.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json b/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json deleted file mode 
100644 index 752b621ab032f24805574708e1659c7139a701a8..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.8286852465281818, 2.0954239138728523, 20.0, 13.546880465642861, 24.14836803774091, [-0.7580548669750213, -0.38900200905253046, -0.06996543062044111, 0.07696197368248522, 0.19961953895336082, 0.2656936063469233, 0.29226296698161564, 0.5177616740489182, 0.7580548669750213, 0.917246389039776, 1.109868130738129, 1.220854943811942, 1.624957930213695, 2.1628190051144287], [2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 3.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json deleted file mode 100644 index 814b4193c638749620e86ac21b86c48747f18f4c..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.088045, "width": 4, "height": 4, "step": 0, "query": "tag=im1%2Fimage%2F0&index=0&run=run1"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json deleted file mode 100644 index 0c2bdcfc79cb32433ac987752851ef6dd351b058..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.093653, "width": 4, "height": 4, "step": 0, "query": "tag=im2%2Fimage%2F0&index=0&run=run1"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json deleted file mode 100644 index 
3160aae366d904d5be5be22d60ca1b345a9d5172..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.117463, "width": 4, "height": 4, "step": 0, "query": "tag=im1%2Fimage%2F0&index=0&run=run2"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav b/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav deleted file mode 100644 index f1d24adc0cef5a734e07e8899b9abf8ae26fa228..0000000000000000000000000000000000000000 Binary files a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav b/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav deleted file mode 100644 index 006c84338f7313a225830f121bcd95f457de1708..0000000000000000000000000000000000000000 Binary files a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png deleted file mode 100644 index 346fd0076be28b9338152c4d49a32fc5ed685e44..0000000000000000000000000000000000000000 Binary files a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png deleted file mode 100644 index 26d2d10acaf8511efeb03169853092d09252215b..0000000000000000000000000000000000000000 Binary files 
a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png deleted file mode 100644 index 6c4190629429e0929962c4f20bd1a1602620e4bd..0000000000000000000000000000000000000000 Binary files a/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/logdir b/tensorflow/tensorboard/demo/data/logdir deleted file mode 100644 index b6362b45d777266d6204b23884222a080f789f71..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/foo/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/runs.json b/tensorflow/tensorboard/demo/data/runs.json deleted file mode 100644 index e09039054299cdc3e3453c620761e1ed6e0c0169..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/runs.json +++ /dev/null @@ -1 +0,0 @@ -{"run1": {"scalars": ["foo/sin", "foo/cos", "foo/square", "bar/square"], "run_metadata": [], "compressedHistograms": ["histo1"], "images": ["im1/image/0", "im2/image/0"], "histograms": ["histo1"], "graph": true, "audio": ["au1/audio/0"]}, "run2": {"scalars": ["foo/cos", "foo/square", "bar/square"], "run_metadata": [], "compressedHistograms": ["histo2", "histo1"], "images": ["im1/image/0"], "histograms": ["histo2", "histo1"], "graph": true, "audio": ["au2/audio/0"]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars.json b/tensorflow/tensorboard/demo/data/scalars.json deleted file mode 100644 index bc269395b68a35f7d4481fca05063e46c79c2859..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars.json +++ /dev/null @@ -1 +0,0 @@ -{"run2": 
{"foo/cos": [[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]]}, "run1": {"foo/sin": [[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]], "foo/cos": [[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e1cd0a6a56d3d87b7183f55ac52ba6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json deleted file mode 100644 index 025eaa16e93110da0c50ad03486786ee6e521700..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json 
b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json deleted file mode 100644 index eae69dd78f3b5aa75acec6b5daa08720fad9adba..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e1cd0a6a56d3d87b7183f55ac52ba6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e1cd0a6a56d3d87b7183f55ac52ba6..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json deleted file mode 100644 index dd3593f9d109e81bef5a10c732a9e08e60b3ef4f..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]] \ No newline at end of file diff --git 
a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json deleted file mode 100644 index 0ff9ef0551d0a3053ba16b502d0d6148057df660..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md deleted file mode 100644 index c2885daf93c29b5c39b68619d26623c666e28627..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/http_api.md +++ /dev/null @@ -1,402 +0,0 @@ -# Tensorboard client-server HTTP API - -## Runs, Tags, and Tag Types - -TensorBoard data is organized around the concept of a `run`, which represents -all the related data thrown off by a single execution of TensorFlow, a `tag`, -which groups values of data that come from the same source within a TensorFlow -run, and `tag types`, which are our way of distinguishing different types of -data that have fundamentally different representations and should be processed -on different code paths. For example, a "train" run may have a `scalars` -tag that represents the learning rate, another `scalars` tag that -represents the value of the objective function, a `histograms` tag that reveals -information on weights in a particular layer over time, and an `images` tag that -shows input images flowing into the system. The "eval" run might have an -entirely different set of tag names, or some duplicated tag names. - -The currently supported tag types are `scalars`, `images`, `audio`, -`histograms`, `graph` and `run_metadata`. Each tag type corresponds to a route -(documented below) for retrieving tag data of that type. 
- -All of the data provided comes from TensorFlow events files ('\*.tfevents\*'), -which are written using the SummaryWriter class -(tensorflow/python/training/summary_writer.py), and the data is generated by -summary ops (tensorflow/python/ops/summary_ops.py). The `scalars` come from the -`ScalarSummary` op, the `histograms` from the `HistogramSummary`, the `audio` -from the `AudioSummary`, and the `images` from `ImageSummary`. The tag type -`graph` is special in that it is not a collection of tags of that type, but a -boolean denoting if there is a graph definition associated with the run. The tag -is provided to the summary op (usually as a constant). - -## `data/logdir` - -Returns a JSON object with a key "logdir" that maps to the `logdir` argument -(string) with which Tensorboard started up. Example: -`{logdir: '/foo/logdir/argument'}` - -The `logdir` argument is the path of the directory that contains events files. - -## `data/plugins_listing` - -Returns a dict mapping from plugin name to a boolean indicating whether the -plugin is active. A plugin might be inactive, for instance, if it lacks relevant -data. Every plugin has a key. This route helps the frontend avoid issuing -requests to an inactive plugin - the routes of an inactive plugin do not work. - -## `data/runs` - -Returns an array containing the names of all the runs known to the -TensorBoard backend at this time. Each entry is a string corresponding -to a single run. - -We guarantee that as new runs are created in the log directory, they -will always appear at the end of the list returned by this route. That -is, the order of runs is persistent, and the result of this route is an -“append-only” list. - -Example response: - - ["train_run", "eval"] - -## `/data/plugin/scalars/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -scalar tags present in the corresponding run. 
Here is an example: - - { - "train_run": ["xent", "loss", "learning_rate"], - "eval": ["precision", "recall"] - } - -Note that runs without any scalar tags are included as keys with value the -empty array. - -## `/data/plugin/scalars/scalars?run=foo&tag=bar` - -Returns an array of event_accumulator.SimpleValueEvents ([wall_time, step, -value]) for the given run and tag. wall_time is seconds since epoch. - -Example: - - [ - [1443856985.705543, 1448, 0.7461960315704346], # wall_time, step, value - [1443857105.704628, 3438, 0.5427092909812927], - [1443857225.705133, 5417, 0.5457325577735901], - ... - ] - -If the format parameter is set to 'csv', the response will instead be in CSV -format: - - Wall time,step,value - 1443856985.705543,1448,0.7461960315704346 - 1443857105.704628,3438,0.5427092909812927 - 1443857225.705133,5417,0.5457325577735901 - -## `/data/plugin/histograms/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -histogram tags present in the corresponding run. Here is an example: - - { - "train_run": ["foo_histogram", "bar_histogram"], - "eval": ["foo_histogram", "bar_histogram"] - } - -Note that runs without any histogram tags are included as keys with -value the empty array. - -## `/data/plugin/histograms/histograms?run=foo&tag=bar` - -Returns an array of event_accumulator.HistogramEvents ([wall_time, step, -HistogramValue]) for the given run and tag. A HistogramValue is [min, max, num, -sum, sum_squares, bucket_limit, bucket]. wall_time is seconds since epoch. 
- -Annotated Example: (note - real data is higher precision) - - [ - [ - 1443871386.185149, # wall_time - 235166, # step - [ - -0.66, # minimum value - 0.44, # maximum value - 8.0, # number of items in the histogram - -0.80, # sum of items in the histogram - 0.73, # sum of squares of items in the histogram - [-0.68, -0.62, -0.292, -0.26, -0.11, -0.10, -0.08, -0.07, -0.05, - -0.0525, -0.0434, -0.039, -0.029, -0.026, 0.42, 0.47, 1.8e+308], - # the right edge of each bucket - [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, - 1.0, 0.0] # the number of elements within each bucket - ] - ] - ] - -## `/data/plugin/distributions/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -distribution tags present in the corresponding run. Here is an example: - - { - "train_run": ["foo_histogram", "bar_histogram"], - "eval": ["foo_histogram", "bar_histogram"] - } - -Note that runs without any distribution tags are included as keys with -value the empty array. - -## `/data/plugin/distributions/distributions?run=foo&tag=bar` - -Returns an array of event_accumulator.CompressedHistogramEvents ([wall_time, -step, CompressedHistogramValues]) for the given run and tag. - -CompressedHistogramValues is a list of namedtuples with each tuple specifying -a basis point (bps) as well as an interpolated value of the histogram value -at that basis point. A basis point is 1/100 of a percent. - -The current compression strategy is to choose basis points that correspond to -the median and bands of 1SD, 2SD, and 3SDs around the median. Note that the -current compression strategy does not work well for representing multimodal -data -- this is something that will be improved in a later iteration. 
- -Annotated Example: (note - real data is higher precision) - - [ - [ - 1441154832.580509, # wall_time - 5, # step - [ [0, -3.67], # CompressedHistogramValue for 0th percentile - [2500, -4.19], # CompressedHistogramValue for 25th percentile - [5000, 6.29], - [7500, 1.64], - [10000, 3.67] - ] - ], - ... - ] - -## `/data/plugin/images/images?run=foo&tag=bar` - -Gets a sample of ImageMetadatas for the given run and tag. - -Returns an array of objects containing information about available images, -crucially including the query parameter that may be used to retrieve that image. -(See /data/plugin/images/individualImage for details.) - -For example: - - { - "width": 28, # width in pixels - "height": 28, # height in pixels - "wall_time": 1440210599.246, # time in seconds since epoch - "step": 63702821, # number of steps that have passed - "query": "index=0&tagname=input%2Fimage%2F2&run=train" - # param for /individualImage - } - -## `/data/plugin/images/individualImage?{{query}}` - -Retrieves an individual image. The image query should not be generated by the -frontend, but instead acquired from calling the /images route (the image -metadata objects contain the query to use). The response is the image itself -with mime-type 'image/png'. - -Note that the query is not guaranteed to always refer to the same image even -within a single run, as images may be removed from the sampling reservoir and -replaced with other images. (See Notes for details on the reservoir sampling.) - -An example call to this route would look like this: -/data/plugin/images/individualImage?index=0&tagname=input%2Fimage%2F2&run=train - -## `/data/plugin/images/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all image -tags present in the corresponding run. 
Here is an example: - - { - "train": ["foo_image", "bar_image"], - "eval": ["foo_image", "bar_image"] - } - -Note that runs without any image tags are included as keys with value the empty -array. - -## `/data/plugin/audio/audio?run=foo&tag=bar` - -Gets a sample of AudioMetadatas for the given run and tag. - -Returns an array of objects containing information about available audio, -crucially including the query parameter that may be used to retrieve that audio. -(See /data/plugin/audio/individualAudio for details.) - -For example: - - { - "wall_time": 1440210599.246, # time in seconds since epoch - "step": 63702821, # number of steps that have passed - "content_type": "audio/wav" # the MIME-type of the audio - "query": "index=0&tagname=input%2Faudio%2F2&run=train" - # param for /individualAudio - } - -## `/data/plugin/audio/individualAudio?{{query}}` - -Retrieves an individual audio clip. The audio query should not be generated by -the frontend, but instead acquired from calling the /audio route (the audio -metadata objects contain the query to use). The response is the audio itself -with an appropriate Content-Type header set. - -Note that the query is not guaranteed to always refer to the same clip even -within a single run, as audio may be removed from the sampling reservoir and -replaced with other clips. (See Notes for details on the reservoir sampling.) - -An example call to this route would look like this: -/individualAudio?index=0&tagname=input%2Faudio%2F2&run=train - -## `/data/plugin/audio/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all audio -tags present in the corresponding run. Here is an example: - - { - "train": ["foo_audio", "bar_audio"], - "eval": ["foo_audio", "bar_audio"], - } - -Note that runs without any audio tags are included as keys with value the empty -array. 
- -## `/data/plugin/graphs/runs` - -Returns a list of runs that have associated graphs. - -For example: - - ["train"] - -## `/data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` - -Returns the graph definition for the given run in pbtxt format. The -graph is composed of a list of nodes, where each node is a specific -TensorFlow operation which takes as inputs other nodes (operations). - -The query parameters `limit_attr_size` and `large_attrs_key` are optional. - -`limit_attr_size` specifies the maximum allowed size in bytes, before the -attribute is considered large and filtered out of the graph. If specified, -it must be an int and > 0. If not specified, no filtering is applied. - -`large_attrs_key` is the attribute key that will be used for storing -attributes that are too large. The value of this key (list of strings) -should be used by the client in order to determine which attributes -have been filtered. Must be specified if `limit_attr_size` is specified. - -For the query - - /data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large, - -here is an example pbtxt response of a graph with 3 nodes, where the second -node had two large attributes "a" and "b" that were filtered out (size > 1024): - - node { - op: "Input" - name: "A" - } - node { - op: "Input" - name: "B" - attr { - key: "small_attr" - value: { - s: "some string" - } - } - attr { - key: "_too_large" - value { - list { - s: "a" - s: "b" - } - } - } - } - node { - op: "MatMul" - name: "C" - input: "A" - input: "B" - } - -Prior to filtering, the original node "B" had the following content: - - node { - op: "Input" - name: "B" - attr { - key: "small_attr" - value: { - s: "some string" - } - } - attr { - key: "a" - value { Very large object... } - } - attr { - key: "b" - value { Very large object... 
} - } - } - -## `/data/run_metadata?run=foo&tag=bar` - -Given a run and tag, returns the metadata of a particular -`session.run()` as a gzipped, pbtxt serialized [`RunMetadata`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto) -proto. For example: - - step_stats { - dev_stats { - device: "/job:localhost/replica:0/task:0/cpu:0" - node_stats { - node_name: "_SOURCE" - all_start_micros: 1458337695775395 - op_start_rel_micros: 11 - op_end_rel_micros: 12 - all_end_rel_micros: 38 - memory { - allocator_name: "cpu" - } - timeline_label: "_SOURCE = NoOp()" - scheduled_micros: 1458337695775363 - } - } - } - -## Notes - -All returned values, histograms, audio, and images are returned in the order -they were written by TensorFlow (which should correspond to increasing -`wall_time` order, but may not necessarily correspond to increasing step count -if the process had to restart from a previous checkpoint). - -The returned values may be downsampled using reservoir sampling, which is -configurable by the TensorBoard server. When downsampling occurs, the server -guarantees that different tags will all sample at the same sequence of indices, -so that if there are two tags `A` and `B` which are related so that `A[i] ~ -B[i]` for all `i`, then `D(A)[i] ~ D(B)[i]` for all `i`, where `D` represents -the downsampling operation. - -The reservoir sampling puts an upper bound on the number of items that will be -returned for a given run-tag combination, and guarantees that all items are -equally likely to be in the final sample (ie it is a uniform distribution over -the values), with the proviso that the most recent individual item is always -included in the sample. - -The reservoir sizes are configurable on a per-tag type basis. 
diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD deleted file mode 100644 index f1f7746ff846e549f3473412470bbff3970a7741..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -java_binary( - name = "Vulcanize", - srcs = ["Vulcanize.java"], - jvm_flags = [ - "-Xss20m", # JSCompiler needs big stacks for recursive parsing - "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive - "-Djava.util.logging.SimpleFormatter.format='%1$$tY-%1$$tm-%1$$td %1$$tH:%1$$tM:%1$$tS.%1$$tL %4$$-6s %5$$s%6$$s%n'", # Less log spam - ], - visibility = ["//visibility:public"], - deps = [ - "@com_google_guava", - "@com_google_protobuf_java", - "@io_bazel_rules_closure//closure/compiler", - "@io_bazel_rules_closure//java/io/bazel/rules/closure:webpath", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles:build_info_java_proto", - "@io_bazel_rules_closure//java/org/jsoup/nodes", - "@org_jsoup", - ], -) - -java_binary( - name = "Zipper", - srcs = ["Zipper.java"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_guava", - "@com_google_protobuf_java", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles:build_info_java_proto", - ], -) - -# These JS files are always taken into consideration by the Closure Compiler -# when vulcanizing, per vulcanize.bzl. 
-filegroup( - name = "jslibs", - srcs = [ - # Ordering probably matters - "@com_google_javascript_closure_compiler_externs", - "@com_google_javascript_closure_compiler_externs_polymer", - "externs.js", - "@com_google_javascript_closure_library//:closure/goog/base.js", - "@com_google_javascript_closure_library//:closure/goog/deps.js", - ], - visibility = ["//visibility:public"], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java deleted file mode 100644 index 533907dd64dd84107d46dd7411235c4ff8aaa755..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java +++ /dev/null @@ -1,546 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package org.tensorflow.tensorboard.vulcanize; - -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Verify.verify; -import static com.google.common.base.Verify.verifyNotNull; -import static java.nio.charset.StandardCharsets.UTF_8; - -import com.google.common.base.CharMatcher; -import com.google.common.base.Joiner; -import com.google.common.base.Optional; -import com.google.common.base.Splitter; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; -import com.google.common.collect.Multimap; -import com.google.javascript.jscomp.CheckLevel; -import com.google.javascript.jscomp.CompilationLevel; -import com.google.javascript.jscomp.Compiler; -import com.google.javascript.jscomp.CompilerOptions; -import com.google.javascript.jscomp.DiagnosticGroup; -import com.google.javascript.jscomp.DiagnosticGroups; -import com.google.javascript.jscomp.DiagnosticType; -import com.google.javascript.jscomp.JSError; -import com.google.javascript.jscomp.ModuleIdentifier; -import com.google.javascript.jscomp.PropertyRenamingPolicy; -import com.google.javascript.jscomp.Result; -import com.google.javascript.jscomp.SourceFile; -import com.google.javascript.jscomp.WarningsGuard; -import com.google.protobuf.TextFormat; -import io.bazel.rules.closure.Webpath; -import io.bazel.rules.closure.webfiles.BuildInfo.Webfiles; -import io.bazel.rules.closure.webfiles.BuildInfo.WebfilesSource; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.nio.file.StandardOpenOption; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Deque; -import java.util.HashMap; -import 
java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import org.jsoup.Jsoup; -import org.jsoup.nodes.Attribute; -import org.jsoup.nodes.Comment; -import org.jsoup.nodes.DataNode; -import org.jsoup.nodes.Document; -import org.jsoup.nodes.Element; -import org.jsoup.nodes.Html5Printer; -import org.jsoup.nodes.Node; -import org.jsoup.nodes.TextNode; -import org.jsoup.parser.Parser; -import org.jsoup.parser.Tag; - -/** Simple one-off solution for TensorBoard vulcanization. */ -public final class Vulcanize { - - private static final Pattern IGNORE_PATHS_PATTERN = - Pattern.compile("/(?:polymer|marked-element)/.*"); - - private static final ImmutableSet EXTRA_JSDOC_TAGS = - ImmutableSet.of("attribute", "hero", "group", "required"); - - private static final Pattern WEBPATH_PATTERN = Pattern.compile("//~~WEBPATH~~([^\n]+)"); - - private static final Parser parser = Parser.htmlParser(); - private static final Map webfiles = new HashMap<>(); - private static final Set alreadyInlined = new HashSet<>(); - private static final Set legalese = new HashSet<>(); - private static final List licenses = new ArrayList<>(); - private static final List stack = new ArrayList<>(); - private static final List externs = new ArrayList<>(); - private static final List sourcesFromJsLibraries = new ArrayList<>(); - private static final Map sourcesFromScriptTags = new LinkedHashMap<>(); - private static final Map sourceTags = new LinkedHashMap<>(); - private static final Multimap suppressions = HashMultimap.create(); - private static CompilationLevel compilationLevel; - private static Webpath outputPath; - private static Node firstCompiledScript; - private static Node licenseComment; - private static int insideDemoSnippet; - private static boolean testOnly; - - public static void main(String[] args) throws IOException { - 
compilationLevel = CompilationLevel.fromString(args[0]); - testOnly = args[1].equals("true"); - Webpath inputPath = Webpath.get(args[2]); - outputPath = Webpath.get(args[3]); - Path output = Paths.get(args[4]); - for (int i = 5; i < args.length; i++) { - if (args[i].endsWith(".js")) { - String code = new String(Files.readAllBytes(Paths.get(args[i])), UTF_8); - SourceFile sourceFile = SourceFile.fromCode(args[i], code); - if (code.contains("@externs")) { - externs.add(sourceFile); - } else { - sourcesFromJsLibraries.add(sourceFile); - } - continue; - } - if (!args[i].endsWith(".pbtxt")) { - continue; - } - Webfiles manifest = loadWebfilesPbtxt(Paths.get(args[i])); - for (WebfilesSource src : manifest.getSrcList()) { - webfiles.put(Webpath.get(src.getWebpath()), Paths.get(src.getPath())); - } - } - stack.add(inputPath); - Document document = parse(Files.readAllBytes(webfiles.get(inputPath))); - transform(document); - compile(); - if (licenseComment != null) { - licenseComment.attr("comment", String.format("\n%s\n", Joiner.on("\n\n").join(licenses))); - } - Files.write( - output, - Html5Printer.stringify(document).getBytes(UTF_8), - StandardOpenOption.WRITE, - StandardOpenOption.CREATE, - StandardOpenOption.TRUNCATE_EXISTING); - } - - private static void transform(Node root) throws IOException { - Node node = checkNotNull(root); - Node newNode; - while (true) { - newNode = enterNode(node); - if (node.equals(root)) { - root = newNode; - } - node = newNode; - if (node.childNodeSize() > 0) { - node = node.childNode(0); - } else { - while (true) { - newNode = leaveNode(node); - if (node.equals(root)) { - root = newNode; - } - node = newNode; - if (node.equals(root)) { - return; - } - Node next = node.nextSibling(); - if (next == null) { - if (node.parentNode() == null) { - return; - } - node = verifyNotNull(node.parentNode(), "unexpected root: %s", node); - } else { - node = next; - break; - } - } - } - } - } - - private static Node enterNode(Node node) throws IOException 
{ - if (node.nodeName().equals("demo-snippet")) { - insideDemoSnippet++; - } - if (insideDemoSnippet > 0) { - return node; - } - if (node instanceof Element) { - if (!getAttrTransitive(node, "vulcanize-noinline").isPresent()) { - if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { - // Inline HTML. - node = visitHtmlImport(node); - } else if (node.nodeName().equals("script") - && !shouldIgnoreUri(node.attr("src")) - && !node.hasAttr("jscomp-ignore")) { - node = visitScript(node); - } else if (node.nodeName().equals("link") - && node.attr("rel").equals("stylesheet") - && !node.attr("href").isEmpty() - && !shouldIgnoreUri(node.attr("href"))) { - node = visitStylesheet(node); - } - } - rootifyAttribute(node, "href"); - rootifyAttribute(node, "src"); - rootifyAttribute(node, "action"); - rootifyAttribute(node, "assetpath"); - } else if (node instanceof Comment) { - String text = ((Comment) node).getData(); - if (text.contains("@license")) { - handleLicense(text); - if (licenseComment == null) { - licenseComment = node; - } else { - node = replaceNode(node, new TextNode("", node.baseUri())); - } - } else { - node = replaceNode(node, new TextNode("", node.baseUri())); - } - } - return node; - } - - private static Node leaveNode(Node node) { - if (node instanceof Document) { - stack.remove(stack.size() - 1); - } else if (node.nodeName().equals("demo-snippet")) { - insideDemoSnippet--; - } - return node; - } - - private static Node visitHtmlImport(Node node) throws IOException { - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - if (alreadyInlined.add(href)) { - stack.add(href); - Document subdocument = parse(Files.readAllBytes(getWebfile(href))); - for (Attribute attr : node.attributes()) { - subdocument.attr(attr.getKey(), attr.getValue()); - } - return replaceNode(node, subdocument); - } else { - return replaceNode(node, new TextNode("", node.baseUri())); - } - } - - private static Node visitScript(Node node) throws IOException { - 
Webpath path; - String script; - if (node.attr("src").isEmpty()) { - path = makeSyntheticName(".js"); - script = getInlineScriptFromNode(node); - } else { - path = me().lookup(Webpath.get(node.attr("src"))); - script = new String(Files.readAllBytes(getWebfile(path)), UTF_8); - } - if (node.attr("src").endsWith(".min.js") - || getAttrTransitive(node, "jscomp-nocompile").isPresent()) { - Node newScript = - new Element(Tag.valueOf("script"), node.baseUri(), node.attributes()) - .appendChild(new DataNode(script, node.baseUri())) - .removeAttr("src") - .removeAttr("jscomp-nocompile"); - if (firstCompiledScript != null) { - firstCompiledScript.before(newScript); - return replaceNode(node, new TextNode("", node.baseUri())); - } else { - return replaceNode(node, newScript); - } - } else { - if (firstCompiledScript == null) { - firstCompiledScript = node; - } - sourcesFromScriptTags.put(path, script); - sourceTags.put(path, node); - Optional suppress = getAttrTransitive(node, "jscomp-suppress"); - if (suppress.isPresent()) { - if (suppress.get().isEmpty()) { - suppressions.put(path, "*"); - } else { - suppressions.putAll(path, Splitter.on(' ').split(suppress.get())); - } - } - return node; - } - } - - private static Node visitStylesheet(Node node) throws IOException { - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - return replaceNode( - node, - new Element(Tag.valueOf("style"), node.baseUri(), node.attributes()) - .appendChild( - new DataNode( - new String(Files.readAllBytes(getWebfile(href)), UTF_8), node.baseUri())) - .removeAttr("rel") - .removeAttr("href")); - } - - private static Optional getAttrTransitive(Node node, String attr) { - while (node != null) { - if (node.hasAttr(attr)) { - return Optional.of(node.attr(attr)); - } - node = node.parent(); - } - return Optional.absent(); - } - - private static Node replaceNode(Node oldNode, Node newNode) { - oldNode.replaceWith(newNode); - return newNode; - } - - private static Path getWebfile(Webpath path) { 
- return verifyNotNull(webfiles.get(path), "Bad ref: %s -> %s", me(), path); - } - - private static void compile() { - if (sourcesFromScriptTags.isEmpty()) { - return; - } - - CompilerOptions options = new CompilerOptions(); - compilationLevel.setOptionsForCompilationLevel(options); - - // Nice options. - options.setColorizeErrorOutput(true); - options.setContinueAfterErrors(true); - options.setLanguageIn(CompilerOptions.LanguageMode.ECMASCRIPT_2016); - options.setLanguageOut(CompilerOptions.LanguageMode.ECMASCRIPT5); - options.setGenerateExports(true); - options.setStrictModeInput(false); - options.setExtraAnnotationNames(EXTRA_JSDOC_TAGS); - - // So we can chop JS binary back up into the original script tags. - options.setPrintInputDelimiter(true); - options.setInputDelimiter("//~~WEBPATH~~%name%"); - - // Optimizations that are too advanced for us right now. - options.setPropertyRenaming(PropertyRenamingPolicy.OFF); - options.setCheckGlobalThisLevel(CheckLevel.OFF); - options.setRemoveUnusedPrototypeProperties(false); - options.setRemoveUnusedPrototypePropertiesInExterns(false); - options.setRemoveUnusedClassProperties(false); - - // Dependency management. - options.setClosurePass(true); - options.setManageClosureDependencies(true); - options.getDependencyOptions().setDependencyPruning(true); - options.getDependencyOptions().setDependencySorting(true); - options.getDependencyOptions().setMoocherDropping(false); - options.getDependencyOptions() - .setEntryPoints( - sourceTags - .keySet() - .stream() - .map(Webpath::toString) - .map(ModuleIdentifier::forFile) - .collect(Collectors.toList())); - - // Polymer pass. - options.setPolymerVersion(1); - - // Debug flags. 
- if (testOnly) { - options.setPrettyPrint(true); - options.setGeneratePseudoNames(true); - options.setExportTestFunctions(true); - } - - // Don't print warnings from " - sanitized = "<script>alert('xss')</script>" - self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) - - dangerous = textwrap.dedent("""\ - hello *you*""") - sanitized = '

hello you

' - self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) - - def testTableGeneration(self): - array2d = np.array([['one', 'two'], ['three', 'four']]) - expected_table = textwrap.dedent("""\ - - - - - - - - - - - -
onetwo
threefour
""") - self.assertEqual(text_plugin.make_table(array2d), expected_table) - - expected_table_with_headers = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - -
c1c2
onetwo
threefour
""") - - actual_with_headers = text_plugin.make_table(array2d, headers=['c1', 'c2']) - self.assertEqual(actual_with_headers, expected_table_with_headers) - - array_1d = np.array(['one', 'two', 'three', 'four', 'five']) - expected_1d = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - - -
one
two
three
four
five
""") - self.assertEqual(text_plugin.make_table(array_1d), expected_1d) - - expected_1d_with_headers = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - - - - - - - -
X
one
two
three
four
five
""") - actual_1d_with_headers = text_plugin.make_table(array_1d, headers=['X']) - self.assertEqual(actual_1d_with_headers, expected_1d_with_headers) - - def testMakeTableExceptions(self): - # Verify that contents is being type-checked and shape-checked. - with self.assertRaises(ValueError): - text_plugin.make_table([]) - - with self.assertRaises(ValueError): - text_plugin.make_table('foo') - - with self.assertRaises(ValueError): - invalid_shape = np.full((3, 3, 3), 'nope', dtype=np.dtype('S3')) - text_plugin.make_table(invalid_shape) - - # Test headers exceptions in 2d array case. - test_array = np.full((3, 3), 'foo', dtype=np.dtype('S3')) - with self.assertRaises(ValueError): - # Headers is wrong type. - text_plugin.make_table(test_array, headers='foo') - with self.assertRaises(ValueError): - # Too many headers. - text_plugin.make_table(test_array, headers=['foo', 'bar', 'zod', 'zoink']) - with self.assertRaises(ValueError): - # headers is 2d - text_plugin.make_table(test_array, headers=test_array) - - # Also make sure the column counting logic works in the 1d array case. - test_array = np.array(['foo', 'bar', 'zod']) - with self.assertRaises(ValueError): - # Too many headers. - text_plugin.make_table(test_array, headers=test_array) - - def test_reduce_to_2d(self): - - def make_range_array(dim): - """Produce an incrementally increasing multidimensional array. - - Args: - dim: the number of dimensions for the array - - Returns: - An array of increasing integer elements, with dim dimensions and size - two in each dimension. - - Example: rangeArray(2) results in [[0,1],[2,3]]. - """ - return np.array(range(2**dim)).reshape([2] * dim) - - for i in range(2, 5): - actual = text_plugin.reduce_to_2d(make_range_array(i)) - expected = make_range_array(2) - np.testing.assert_array_equal(actual, expected) - - def test_text_array_to_html(self): - - convert = text_plugin.text_array_to_html - scalar = np.array('foo') - scalar_expected = '

foo

' - self.assertEqual(convert(scalar), scalar_expected) - - vector = np.array(['foo', 'bar']) - vector_expected = textwrap.dedent("""\ - - - - - - - - - -

foo

bar

""") - self.assertEqual(convert(vector), vector_expected) - - d2 = np.array([['foo', 'bar'], ['zoink', 'zod']]) - d2_expected = textwrap.dedent("""\ - - - - - - - - - - - -

foo

bar

zoink

zod

""") - self.assertEqual(convert(d2), d2_expected) - - d3 = np.array([[['foo', 'bar'], ['zoink', 'zod']], [['FOO', 'BAR'], - ['ZOINK', 'ZOD']]]) - - warning = text_plugin.markdown_and_sanitize(text_plugin.WARNING_TEMPLATE % - 3) - d3_expected = warning + textwrap.dedent("""\ - - - - - - - - - - - -

foo

bar

zoink

zod

""") - self.assertEqual(convert(d3), d3_expected) - - def testPluginIsActive(self): - plugin = text_plugin.TextPlugin() - multiplexer = event_multiplexer.EventMultiplexer() - plugin.get_plugin_apps(event_multiplexer.EventMultiplexer(), None) - - # The plugin is inactive because text summaries are not available. - self.assertFalse(plugin.is_active()) - - multiplexer.AddRunsFromDirectory(self.logdir) - multiplexer.Reload() - - # The plugin is active because text summaries are available. - self.assertTrue(self.plugin.is_active()) - - def testUnicode(self): - self.assertConverted(u'

Iñtërnâtiônàlizætiøn⚡💩

', - 'Iñtërnâtiônàlizætiøn⚡💩') - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/scripts/BUILD b/tensorflow/tensorboard/scripts/BUILD deleted file mode 100644 index 05425ee61d05e3a0e540106a8c313205562b347c..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/scripts/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -# Description: -# Some useful scripts that are bundled with TensorBoard. - -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_binary( - name = "generate_testdata", - srcs = ["generate_testdata.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_binary( - name = "execrooter", - srcs = ["execrooter.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], -) - -filegroup( - name = "all_files", - srcs = glob(["*"]), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/tensorboard/scripts/__init__.py b/tensorflow/tensorboard/scripts/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/tensorflow/tensorboard/scripts/execrooter.py b/tensorflow/tensorboard/scripts/execrooter.py deleted file mode 100644 index 65569b9151258dc692ec45223a4f9118ea803126..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/scripts/execrooter.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utility for running programs in a symlinked execroot.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import os -import shutil -import subprocess -import sys -import tempfile - - -def run(inputs, program, outputs): - """Creates temp symlink tree, runs program, and copies back outputs. - - Args: - inputs: List of fake paths to real paths, which are used for symlink tree. - program: List containing real path of program and its arguments. The - execroot directory will be appended as the last argument. - outputs: List of fake outputted paths to copy back to real paths. - Returns: - 0 if succeeded or nonzero if failed. - """ - root = tempfile.mkdtemp() - try: - cwd = os.getcwd() - for fake, real in inputs: - parent = os.path.join(root, os.path.dirname(fake)) - if not os.path.exists(parent): - os.makedirs(parent) - os.symlink(os.path.join(cwd, real), os.path.join(root, fake)) - if subprocess.call(program + [root]) != 0: - return 1 - for fake, real in outputs: - shutil.copyfile(os.path.join(root, fake), real) - return 0 - finally: - shutil.rmtree(root) - - -def main(args): - """Invokes run function using a JSON file config. - - Args: - args: CLI args, which can be a JSON file containing an object whose - attributes are the parameters to the run function. If multiple JSON - files are passed, their contents are concatenated. - Returns: - 0 if succeeded or nonzero if failed. - Raises: - Exception: If input data is missing. 
- """ - if not args: - raise Exception('Please specify at least one JSON config path') - inputs = [] - program = [] - outputs = [] - for arg in args: - with open(arg) as fd: - config = json.load(fd) - inputs.extend(config.get('inputs', [])) - program.extend(config.get('program', [])) - outputs.extend(config.get('outputs', [])) - if not program: - raise Exception('Please specify a program') - return run(inputs, program, outputs) - - -if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) diff --git a/tensorflow/tensorboard/scripts/generate_testdata.py b/tensorflow/tensorboard/scripts/generate_testdata.py deleted file mode 100644 index f191d16a82dc9f771ea4f1d42a510625c157d119..0000000000000000000000000000000000000000 --- a/tensorflow/tensorboard/scripts/generate_testdata.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Generate some standard test data for debugging TensorBoard. 
-""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import bisect -import math -import os -import os.path -import random -import shutil - -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf - - -tf.flags.DEFINE_string("target", None, """The directory where serialized data -will be written""") - -tf.flags.DEFINE_boolean("overwrite", False, """Whether to remove and overwrite -TARGET if it already exists.""") - -FLAGS = tf.flags.FLAGS - -# Hardcode a start time and reseed so script always generates the same data. -_start_time = 0 -random.seed(0) - - -def _MakeHistogramBuckets(): - v = 1E-12 - buckets = [] - neg_buckets = [] - while v < 1E20: - buckets.append(v) - neg_buckets.append(-v) - v *= 1.1 - # Should include DBL_MAX, but won't bother for test data. - return neg_buckets[::-1] + [0] + buckets - - -def _MakeHistogram(values): - """Convert values into a histogram proto using logic from histogram.cc.""" - limits = _MakeHistogramBuckets() - counts = [0] * len(limits) - for v in values: - idx = bisect.bisect_left(limits, v) - counts[idx] += 1 - - limit_counts = [(limits[i], counts[i]) for i in xrange(len(limits)) - if counts[i]] - bucket_limit = [lc[0] for lc in limit_counts] - bucket = [lc[1] for lc in limit_counts] - sum_sq = sum(v * v for v in values) - return tf.HistogramProto( - min=min(values), - max=max(values), - num=len(values), - sum=sum(values), - sum_squares=sum_sq, - bucket_limit=bucket_limit, - bucket=bucket) - - -def WriteScalarSeries(writer, tag, f, n=5): - """Write a series of scalar events to writer, using f to create values.""" - step = 0 - wall_time = _start_time - for i in xrange(n): - v = f(i) - value = tf.Summary.Value(tag=tag, simple_value=v) - summary = tf.Summary(value=[value]) - event = tf.Event(wall_time=wall_time, step=step, summary=summary) - writer.add_event(event) - step += 1 - wall_time += 10 - - -def 
WriteHistogramSeries(writer, tag, mu_sigma_tuples, n=20): - """Write a sequence of normally distributed histograms to writer.""" - step = 0 - wall_time = _start_time - for [mean, stddev] in mu_sigma_tuples: - data = [random.normalvariate(mean, stddev) for _ in xrange(n)] - histo = _MakeHistogram(data) - summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)]) - event = tf.Event(wall_time=wall_time, step=step, summary=summary) - writer.add_event(event) - step += 10 - wall_time += 100 - - -def WriteImageSeries(writer, tag, n_images=1): - """Write a few dummy images to writer.""" - step = 0 - session = tf.Session() - p = tf.placeholder("uint8", (1, 4, 4, 3)) - s = tf.summary.image(tag, p) - for _ in xrange(n_images): - im = np.random.random_integers(0, 255, (1, 4, 4, 3)) - summ = session.run(s, feed_dict={p: im}) - writer.add_summary(summ, step) - step += 20 - session.close() - - -def WriteAudioSeries(writer, tag, n_audio=1): - """Write a few dummy audio clips to writer.""" - step = 0 - session = tf.Session() - - min_frequency_hz = 440 - max_frequency_hz = 880 - sample_rate = 4000 - duration_frames = sample_rate // 2 # 0.5 seconds. - frequencies_per_run = 1 - num_channels = 2 - - p = tf.placeholder("float32", (frequencies_per_run, duration_frames, - num_channels)) - s = tf.summary.audio(tag, p, sample_rate) - - for _ in xrange(n_audio): - # Generate a different frequency for each channel to show stereo works. 
- frequencies = np.random.random_integers( - min_frequency_hz, - max_frequency_hz, - size=(frequencies_per_run, num_channels)) - tiled_frequencies = np.tile(frequencies, (1, duration_frames)) - tiled_increments = np.tile( - np.arange(0, duration_frames), - (num_channels, 1)).T.reshape(1, duration_frames * num_channels) - tones = np.sin(2.0 * np.pi * tiled_frequencies * tiled_increments / - sample_rate) - tones = tones.reshape(frequencies_per_run, duration_frames, num_channels) - - summ = session.run(s, feed_dict={p: tones}) - writer.add_summary(summ, step) - step += 20 - session.close() - - -def GenerateTestData(path): - """Generates the test data directory.""" - run1_path = os.path.join(path, "run1") - os.makedirs(run1_path) - writer1 = tf.summary.FileWriter(run1_path) - WriteScalarSeries(writer1, "foo/square", lambda x: x * x) - WriteScalarSeries(writer1, "bar/square", lambda x: x * x) - WriteScalarSeries(writer1, "foo/sin", math.sin) - WriteScalarSeries(writer1, "foo/cos", math.cos) - WriteHistogramSeries(writer1, "histo1", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], - [1, 1]]) - WriteImageSeries(writer1, "im1") - WriteImageSeries(writer1, "im2") - WriteAudioSeries(writer1, "au1") - - run2_path = os.path.join(path, "run2") - os.makedirs(run2_path) - writer2 = tf.summary.FileWriter(run2_path) - WriteScalarSeries(writer2, "foo/square", lambda x: x * x * 2) - WriteScalarSeries(writer2, "bar/square", lambda x: x * x * 3) - WriteScalarSeries(writer2, "foo/cos", lambda x: math.cos(x) * 2) - WriteHistogramSeries(writer2, "histo1", [[0, 2], [0.3, 2], [0.5, 2], [0.7, 2], - [1, 2]]) - WriteHistogramSeries(writer2, "histo2", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], - [1, 1]]) - WriteImageSeries(writer2, "im1") - WriteAudioSeries(writer2, "au2") - - graph_def = tf.GraphDef() - node1 = graph_def.node.add() - node1.name = "a" - node1.op = "matmul" - node2 = graph_def.node.add() - node2.name = "b" - node2.op = "matmul" - node2.input.extend(["a:0"]) - - writer1.add_graph(graph_def) - 
node3 = graph_def.node.add() - node3.name = "c" - node3.op = "matmul" - node3.input.extend(["a:0", "b:0"]) - writer2.add_graph(graph_def) - writer1.close() - writer2.close() - - -def main(unused_argv=None): - target = FLAGS.target - if not target: - print("The --target flag is required.") - return -1 - if os.path.exists(target): - if FLAGS.overwrite: - if os.path.isdir(target): - shutil.rmtree(target) - else: - os.remove(target) - else: - print("Refusing to overwrite target %s without --overwrite" % target) - return -2 - GenerateTestData(target) - - -if __name__ == "__main__": - tf.app.run() diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d01342827dc26a80ac0d7f829c4e093afcf76abb..83a9313e50a62bc962ad56be2ec47837bd5ff115 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -138,16 +138,20 @@ WIN_COPTS = [ "/DTF_COMPILE_LIBRARY", "/DEIGEN_HAS_C99_MATH", "/DTENSORFLOW_USE_EIGEN_THREADPOOL", + "/DEIGEN_AVOID_STL_ARRAY", + "/Iexternal/gemmlowp", + "/wd4018", # -Wno-sign-compare + "/U_HAS_EXCEPTIONS", "/D_HAS_EXCEPTIONS=1", "/EHsc", # -fno-exceptions ] # LINT.IfChange def tf_copts(): - return ([ + return (if_not_windows([ "-DEIGEN_AVOID_STL_ARRAY", "-Iexternal/gemmlowp", "-Wno-sign-compare", "-fno-exceptions", - ] + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm( + ]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm( ["-mfpu=neon"]) + if_x86(["-msse3"]) + select({ clean_dep("//tensorflow:android"): [ "-std=c++11", @@ -167,7 +171,7 @@ def tf_opts_nortti_if_android(): "-fno-rtti", "-DGOOGLE_PROTOBUF_NO_RTTI", "-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER", - ]) + if_android_x86(["-msse4.1"]) + ]) # LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) @@ -1021,9 +1025,9 @@ def tf_py_wrap_cc(name, native.cc_binary( name=cc_library_name, srcs=[module_name + ".cc"], - copts=(copts + [ + copts=(copts + if_not_windows([ "-Wno-self-assign", 
"-Wno-sign-compare", "-Wno-write-strings" - ] + tf_extension_copts()), + ]) + tf_extension_copts()), linkopts=tf_extension_linkopts() + extra_linkopts, linkstatic=1, linkshared=1, diff --git a/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt index 72cc53244768ad515c0ce33b937a2eae3a9fd98a..a095616c00cfe8fb64413e2078ae1589a423d2f4 100644 --- a/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt @@ -55,6 +55,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt index 5c77b3dd5cca6c7741764e6b4bcea82ef30a47fd..260c796fd65b90020eb2b8191645ffdb2402a4a4 100644 --- a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt @@ -13,7 +13,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\', \'encoding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "num_records_produced" diff --git a/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt index f5b0bae58d0d11d1fb0b83e3996a038f6254ccdc..0a3b81bf829f48e88e9c48ce26cdbb4207101a16 100644 --- a/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt @@ -34,7 +34,7 @@ tf_class { } member_method { name: "make_callable" - argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'fetches\', \'feed_list\', \'accept_options\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " } member_method { name: "partial_run" diff --git a/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f9b7e9bbca82858ca99e67d70cf93583ca75972f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.LMDBReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + 
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt index 1bfe723ce754830efeebd7644871ff29f9809423..8fed133561544b91abfc64577e63a7088b43a007 100644 --- a/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt @@ -55,6 +55,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt index dbe25f3a5b9ecc1596c77862396c684b6ddb9c5f..ebb017e81bc29e062d804fbe9f50c62f7b615dab 100644 --- a/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt @@ -55,6 +55,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt index 9263d73a51161e9df083992528400b57302832d2..761f90989f316611d42580ee911e24bb3d0d2fec 100644 --- a/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt @@ -54,6 +54,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt b/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt index ec783ffe5a01d66965d6370ec1bc6c83178b5a8c..f3ca84139311bc05478e3dce876b53f7b9dec883 100644 --- a/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt @@ -55,6 +55,10 @@ tf_class { name: "from_list" argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "size" argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.-session.pbtxt index 173cd1963e5e8c088556e8530b65ac1bdee99dc3..1d6b037f9c3540653a8fb18b6508f74b01da66ab 100644 --- a/tensorflow/tools/api/golden/tensorflow.-session.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-session.pbtxt @@ -34,7 +34,7 @@ tf_class { } member_method { name: "make_callable" - argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'fetches\', 
\'feed_list\', \'accept_options\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " } member_method { name: "partial_run" diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt b/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt index d5b9cb8f5ed3cf088f5bd27809ff98f00801217d..8e3598fb2470b327e6e3601969f055d4907f614a 100644 --- a/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt @@ -54,6 +54,10 @@ tf_class { name: "merge_with" argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "most_specific_compatible_shape" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "num_elements" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt b/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1e4d333cc0bb0bb33fb4cc8d76badd30c8babaa4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.bitwise" +tf_module { + member_method { + name: "bitwise_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bitwise_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bitwise_xor" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "invert" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt new file mode 100644 index 
0000000000000000000000000000000000000000..3a6f770153013dc925dc1b65a38ec59202c4b0b2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.DNNClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt new file mode 100644 index 
0000000000000000000000000000000000000000..83e53d3960477b8170664c03ee30f588f87454b9 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.DNNLinearCombinedClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..17f30a04fbfe7ffe464e7d107f8a9d9a27140188 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.DNNLinearCombinedRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git 
a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..edd68f0bb9ac8654dbc53e090d812de37a168515 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.DNNRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3715dd5ec76284004efb24b0b6316d1eec87a589 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.LinearClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt new file mode 100644 index 
0000000000000000000000000000000000000000..ccb4abf675f3c05a14990a5ae0da3068fc0d8a47 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.estimator.LinearRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt index d69c475a313075a5b165dba9a80e30cf8212657d..801260c4507803345c4c84852fd83832b752ac12 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt @@ 
-22,6 +22,10 @@ tf_class { name: "keep_checkpoint_max" mtype: "" } + member { + name: "log_step_count_steps" + mtype: "" + } member { name: "master" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt index 0d5dc73271dbc972c9177a6274f1632862f93ef0..07b04810b5c6d2eda3c3dce5ad4c35592158b085 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt @@ -1,5 +1,21 @@ path: "tensorflow.estimator" tf_module { + member { + name: "DNNClassifier" + mtype: "" + } + member { + name: "DNNLinearCombinedClassifier" + mtype: "" + } + member { + name: "DNNLinearCombinedRegressor" + mtype: "" + } + member { + name: "DNNRegressor" + mtype: "" + } member { name: "Estimator" mtype: "" @@ -8,6 +24,14 @@ tf_module { name: "EstimatorSpec" mtype: "" } + member { + name: "LinearClassifier" + mtype: "" + } + member { + name: "LinearRegressor" + mtype: "" + } member { name: "ModeKeys" mtype: "" @@ -24,4 +48,12 @@ tf_module { name: "inputs" mtype: "" } + member_method { + name: "classifier_parse_example_spec" + argspec: "args=[\'feature_columns\', \'label_key\', \'label_dtype\', \'label_default\', \'weight_column\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'None\'], " + } + member_method { + name: "regressor_parse_example_spec" + argspec: "args=[\'feature_columns\', \'label_key\', \'label_dtype\', \'label_default\', \'label_dimension\', \'weight_column\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'1\', \'None\'], " + } } diff --git a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt index 4c633a850f8e069135f122292bac019e2646aa61..2a57a845cdcb92d2c3e5d87e06d4e03886696be1 100644 --- a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt @@ -18,7 +18,7 @@ tf_module { } 
member_method { name: "categorical_column_with_vocabulary_list" - argspec: "args=[\'key\', \'vocabulary_list\', \'dtype\', \'default_value\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " + argspec: "args=[\'key\', \'vocabulary_list\', \'dtype\', \'default_value\', \'num_oov_buckets\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\', \'0\'], " } member_method { name: "crossed_column" diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 342ee95f74d14e65fba9ece1ce1b1bd1924ab79c..781c0a6b4aa110adcc876b8e320f09f8f82fb9dd 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -108,6 +108,10 @@ tf_module { name: "InteractiveSession" mtype: "" } + member { + name: "LMDBReader" + mtype: "" + } member { name: "LogMessage" mtype: "" @@ -252,6 +256,10 @@ tf_module { name: "bfloat16" mtype: "" } + member { + name: "bitwise" + mtype: "" + } member { name: "bool" mtype: "" @@ -372,6 +380,10 @@ tf_module { name: "orthogonal_initializer" mtype: "" } + member { + name: "profiler" + mtype: "" + } member { name: "python_io" mtype: "" @@ -468,6 +480,10 @@ tf_module { name: "user_ops" mtype: "" } + member { + name: "variance_scaling_initializer" + mtype: "" + } member { name: "zeros_initializer" mtype: "" @@ -500,6 +516,10 @@ tf_module { name: "acos" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "acosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "add" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -522,19 +542,19 @@ tf_module { } member_method { name: "arg_max" - argspec: "args=[\'input\', \'dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'dimension\', \'output_type\', \'name\'], varargs=None, keywords=None, 
defaults=[\'None\', \'None\'], " } member_method { name: "arg_min" - argspec: "args=[\'input\', \'dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'input\', \'dimension\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "argmax" - argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\"], " } member_method { name: "argmin" - argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\"], " } member_method { name: "as_dtype" @@ -548,6 +568,10 @@ tf_module { name: "asin" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "asinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "assert_equal" argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " @@ -640,6 +664,10 @@ tf_module { name: "atan2" argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "atanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "batch_to_space" argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1024,6 +1052,14 @@ tf_module { name: "global_variables_initializer" argspec: "args=[], varargs=None, 
keywords=None, defaults=None" } + member_method { + name: "glorot_normal_initializer" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "glorot_uniform_initializer" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } member_method { name: "gradients" argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\'], " @@ -1696,6 +1732,14 @@ tf_module { name: "sparse_placeholder" argspec: "args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } + member_method { + name: "sparse_reduce_max" + argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "sparse_reduce_max_sparse" + argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } member_method { name: "sparse_reduce_sum" argspec: "args=[\'sp_input\', \'axis\', \'keep_dims\', \'reduction_axes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " @@ -1732,6 +1776,10 @@ tf_module { name: "sparse_segment_sum" argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "sparse_slice" + argspec: "args=[\'sp_input\', \'start\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "sparse_softmax" argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-profiler.pbtxt 
b/tensorflow/tools/api/golden/tensorflow.profiler.-profiler.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0fb363aca48031e13487d716a0375973f93b3dc8 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.-profiler.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.profiler.Profiler" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'graph\', \'op_log\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_step" + argspec: "args=[\'self\', \'step\', \'run_meta\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "advise" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_graph" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_name_scope" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_operations" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_python" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.pbtxt b/tensorflow/tools/api/golden/tensorflow.profiler.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b1a97f2e1057fe0b78be4b254527bd3f3d037d71 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.profiler.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.profiler" +tf_module { + member { + name: "Profiler" + mtype: "" + } + member_method { + name: "advise" + argspec: "args=[\'graph\', \'run_meta\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'0\'], " + } + member_method { + name: "profile" + argspec: "args=[\'graph\', \'run_meta\', \'op_log\', \'cmd\', \'options\'], 
varargs=None, keywords=None, defaults=[\'None\', \'None\', \'scope\', \'0\'], " + } + member_method { + name: "write_op_log" + argspec: "args=[\'graph\', \'log_dir\', \'op_log\', \'run_meta\', \'add_trace\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt index af0c11ca14d4f38547a49ac511ee13e15847eb33..31775de2d12bcd2f214f5a04be7a92f49c594fde 100644 --- a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "close" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "write" argspec: "args=[\'self\', \'record\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/tensorflow.test.pbtxt index 2a88f26ed02c7e2690c37180f76b965d7ffa87e0..6237207821ab18c8eb3e6148875e29e2e2fad773 100644 --- a/tensorflow/tools/api/golden/tensorflow.test.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.test.pbtxt @@ -30,7 +30,7 @@ tf_module { } member_method { name: "create_local_cluster" - argspec: "args=[\'num_workers\', \'num_ps\', \'protocol\'], varargs=None, keywords=None, defaults=[\'grpc\'], " + argspec: "args=[\'num_workers\', \'num_ps\', \'protocol\', \'worker_config\', \'ps_config\'], varargs=None, keywords=None, defaults=[\'grpc\', \'None\', \'None\'], " } member_method { name: "get_temp_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt index 
2dc11df57b60b15a797b1866743b27ea1068624e..5cff6087ef533f6674d6d7f1e0a8be425c16f2ad 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\'], " } member_method { name: "apply_gradients" diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index 58fd5760c11d29f063c0f7f66ea0a11d39a08a1e..89c299ae994bcd4f6ceb6daa632f985247d3db7f 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -230,7 +230,7 @@ tf_module { } member_method { name: "MonitoredTrainingSession" - argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'600\', \'100\', \'None\', \'None\', \'120\', \'100\'], " + argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', 
\'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'600\', \'\', \'\', \'None\', \'120\', \'100\'], " } member_method { name: "NewCheckpointReader" @@ -304,6 +304,10 @@ tf_module { name: "import_meta_graph" argspec: "args=[\'meta_graph_or_file\', \'clear_devices\', \'import_scope\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\'], " } + member_method { + name: "init_from_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "input_producer" argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], " @@ -320,6 +324,18 @@ tf_module { name: "limit_epochs" argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } + member_method { + name: "list_variables" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_variable" + argspec: "args=[\'ckpt_dir_or_file\', \'name\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "match_filenames_once" argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt new file mode 100644 index 
0000000000000000000000000000000000000000..a58398d645e8397dc8e61a6e0241710c3e34218f --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.variance_scaling_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/lib/BUILD b/tensorflow/tools/api/lib/BUILD index cdfa0e7be524e3bb4ec039ac19bea72747afb58c..2d3b838957d60ffb5e827c6b43100d217cc5739e 100644 --- a/tensorflow/tools/api/lib/BUILD +++ b/tensorflow/tools/api/lib/BUILD @@ -22,7 +22,8 @@ py_library( srcs_version = "PY2AND3", deps = [ ":api_objects_proto_py", - "//tensorflow/tools/common:traverse", + "//tensorflow/python:platform", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 8421d8fce28611f6049847f6fbca5538475b59af..e9aeeb385586e3abd129d9a475d89545efaca45b 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -23,11 +23,12 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/tools/api/lib:python_object_to_proto_visitor", "//tensorflow/tools/common:public_api", "//tensorflow/tools/common:traverse", - "@protobuf//:protobuf_python", ], ) diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index 
dfad11adf0b971748cbc64f9b86fd6cb2c7cdd37..81892aef96996c341ff750abf7a347059f6e9446 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/tools/ci_build/Dockerfile.tensorboard b/tensorflow/tools/ci_build/Dockerfile.tensorboard deleted file mode 100644 index 9795872e2c4907908c288f8901d0a007f8d1dcaa..0000000000000000000000000000000000000000 --- a/tensorflow/tools/ci_build/Dockerfile.tensorboard +++ /dev/null @@ -1,11 +0,0 @@ -FROM ubuntu:14.04 - -MAINTAINER Jan Prach - -# Copy and run the install scripts. -COPY install/*.sh /install/ -RUN /install/install_bootstrap_deb_packages.sh -RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:george-edison55/cmake-3.x -RUN /install/install_deb_packages.sh -RUN /install/install_tensorboard_packages.sh diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh index 85c712d3c6db353574fda40363d58dc328259430..7cb93c1774d9e587dabebf4325eb03799d815b3c 100755 --- a/tensorflow/tools/ci_build/builds/pip.sh +++ b/tensorflow/tools/ci_build/builds/pip.sh @@ -23,7 +23,7 @@ # # When executing the Python unit tests, the script obeys the shell # variables: TF_BUILD_BAZEL_CLEAN, TF_BUILD_INSTALL_EXTRA_PIP_PACKAGES, -# NO_TEST_ON_INSTALL +# NO_TEST_ON_INSTALL, PIP_TEST_ROOT # # TF_BUILD_BAZEL_CLEAN, if set to any non-empty and non-0 value, directs the # script to perform bazel clean prior to main build and test steps. @@ -41,6 +41,9 @@ # If NO_TEST_TFDBG_BINARIES has any non-empty and non-0 value, the testing of # TensorFlow Debugger (tfdbg) binaries and examples will be skipped. 
# +# If PIP_TEST_ROOT has a non-empty and a non-0 value, the whl files will be +# placed in that directory. +# # Any flags not listed in the usage above will be passed directly to Bazel. # # If the --test_tutorials flag is set, it will cause the script to run the @@ -162,7 +165,10 @@ echo "Python binary path to be used in PIP install: ${PYTHON_BIN_PATH} "\ "(Major.Minor version: ${PY_MAJOR_MINOR_VER})" # Build PIP Wheel file -PIP_TEST_ROOT="pip_test" +# Set default pip file folder unless specified by env variable +if [ -z "$PIP_TEST_ROOT" ]; then + PIP_TEST_ROOT="pip_test" +fi PIP_WHL_DIR="${PIP_TEST_ROOT}/whl" PIP_WHL_DIR=$(realpath ${PIP_WHL_DIR}) # Get absolute path rm -rf ${PIP_WHL_DIR} && mkdir -p ${PIP_WHL_DIR} diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 1cf87d7c7c09613d2a7f265e5cc1b54a3e2ae47e..4ca08052bad707120ee4da243c384dfd5f65654f 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -358,7 +358,7 @@ if [[ "${TF_BUILD_APPEND_ARGUMENTS}" == *"--test_tag_filters="* ]]; then fi done else - EXTRA_ARGS="${TF_BUILD_APPEND_ARGUMENTS} --test_tag_filters=-benchmark-test" + EXTRA_ARGS="${TF_BUILD_APPEND_ARGUMENTS} --test_tag_filters=-no_oss,-benchmark-test" if [[ ${IS_MAC} == "1" ]]; then EXTRA_ARGS="${EXTRA_ARGS},-nomac" fi diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index c9867796f3a3de9ab439859c6822e6fe02fab6dc..44fc21df9458c0880d1972603c93e1590e2b0643 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -86,5 +86,6 @@ pip2 install mock pip2 install portpicker pip3 install portpicker -pip2 install backports.weakref==1.0rc1 -pip3 install backports.weakref==1.0rc1 +# TensorFlow Serving integration tests require the following: +pip2 install 
grpcio +pip3 install grpcio diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 33b3bc104bd04f61c7e1247a97cc7a850c2984fd..084ac49496cf576cabd8abeb2284692f47cb6649 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -89,6 +89,3 @@ pip3.5 install wheel==0.29.0 pip3.5 install portpicker pip3.5 install werkzeug - -pip3.5 install backports.weakref==1.0rc1 - diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh index 467e4ab7e53ebd1c6985bcc908c9efdda10cef17..118e85fee0b55c735914f97394308ef8b92ca4c1 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh @@ -33,7 +33,7 @@ export PYTHON_BIN_PATH=`which python` yes "" | ./configure # Run bazel test command. Double test timeouts to avoid flakes. -bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=cc -k \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh index e2bbc0e8c0be0d1069eb85364ba8a137b950cb3a..4a30c73417551671b241d2cffc6a51fb7af6c50b 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh @@ -32,7 +32,7 @@ export PYTHON_BIN_PATH=`which python2` yes "" | ./configure # Run bazel test command. Double test timeouts to avoid flakes. 
-bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=py -k \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=py -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh index a03cab0cca5c375e668a2adeae64c48ac2b217a0..d224878643e6e42675bd6463240d822007d51768 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh @@ -32,7 +32,7 @@ export PYTHON_BIN_PATH=`which python3` yes "" | ./configure # Run bazel test command. Double test timeouts to avoid flakes. -bazel test --test_tag_filters=-gpu,-benchmark-test -k \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --test_output=errors -- \ //tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh index 32de5cea200d4a43e5885364a9aeeafd2fa51af6..39c241fb5393d2d475433856cf42acb790813cfd 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh @@ -32,7 +32,7 @@ export PYTHON_BIN_PATH=`which python3` yes "" | ./configure # Run bazel test command. Double test timeouts to avoid flakes. -bazel test --test_tag_filters=-gpu,-benchmark-test --test_lang_filters=py -k \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=py -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... 
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh index 6acc26213835c0d2924f9ee0a31a80790bf5d75e..4d7b4741d66c2131ced2ac57c29f6667a469e598 100755 --- a/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh +++ b/tensorflow/tools/ci_build/linux/gpu/run_cc_core.sh @@ -35,7 +35,7 @@ export TF_CUDA_COMPUTE_CAPABILITIES=3.7 yes "" | ./configure # Run bazel test command. Double test timeouts to avoid flakes. -bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \ +bazel test --config=cuda --test_tag_filters=-no_oss,-no_gpu,-benchmark-test -k \ --test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --build_tests_only --test_output=errors --local_test_jobs=8 \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ diff --git a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh index e73fe046c967b0bb3db6eb5b109516c0d207a1e4..c0bcec4a17db54445bd1ebafe9714f985bb28b2f 100755 --- a/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh +++ b/tensorflow/tools/ci_build/linux/gpu/run_py3_core.sh @@ -35,7 +35,7 @@ export TF_CUDA_COMPUTE_CAPABILITIES=3.7 yes "" | ./configure # Run bazel test command. Double test timeouts to avoid flakes. 
-bazel test --config=cuda --test_tag_filters=-no_gpu,-benchmark-test -k \ +bazel test --config=cuda --test_tag_filters=-no_oss,-no_gpu,-benchmark-test -k \ --test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --build_tests_only --test_output=errors --local_test_jobs=8 \ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh index e5f4a22f7ade7eb5c260a7a486cd5d3fa75d5859..0b8c73993f8e27cf58b3f7a8756e01c8304d837f 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh @@ -32,9 +32,8 @@ export TF_NEED_CUDA=0 export PYTHON_BIN_PATH=$(which python2) yes "" | ./configure which bazel -bazel test --test_tag_filters=-gpu,-benchmark-test,-nomac \ +bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ --test_timeout 300,450,1200,3600 \ --test_size_filters=small,medium \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ - //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... \ - -//tensorflow/tensorboard/... + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/update_version.sh b/tensorflow/tools/ci_build/update_version.sh index 682f5329f58fffa5f2030c7e33db14bd3e343165..b707ee338a2786ce3946d9e3d34da311b9f512f5 100755 --- a/tensorflow/tools/ci_build/update_version.sh +++ b/tensorflow/tools/ci_build/update_version.sh @@ -130,12 +130,6 @@ if [[ ${OLD_MAJOR} != ${MAJOR} ]] || [[ ${OLD_MINOR} != ${MINOR} ]]; then echo "Detected Major.Minor change. 
"\ "Updating pattern ${OLD_R_MAJOR_MINOR} to ${R_MAJOR_MINOR} in additional files" - # Update tensorflow/tensorboard/README.md - TENSORBOARD_README_MD="${TF_SRC_DIR}/tensorboard/README.md" - check_existence file "${TENSORBOARD_README_MD}" - sed -i -r -e "s/${OLD_R_MAJOR_MINOR}/${R_MAJOR_MINOR}/g" \ - "${TENSORBOARD_README_MD}" - # Update dockerfiles DEVEL_DOCKERFILE="${TF_SRC_DIR}/tools/docker/Dockerfile.devel" check_existence file "${DEVEL_DOCKERFILE}" diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index f92edd0dd8863fa7a3a6ad764a895370d48a5958..8a8667957ae4acb97356d4a141edd422509b48c7 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -15,6 +15,7 @@ py_library( name = "public_api", srcs = ["public_api.py"], srcs_version = "PY2AND3", + deps = ["//tensorflow/python:util"], ) py_test( @@ -32,6 +33,7 @@ py_library( name = "traverse", srcs = ["traverse.py"], srcs_version = "PY2AND3", + deps = ["//tensorflow/python:util"], ) py_test( diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD index fb40cf0833f08fc142aec18fe8940ce836453906..19959ea6d260d5aded5a3f37850025f6722d82ee 100644 --- a/tensorflow/tools/compatibility/BUILD +++ b/tensorflow/tools/compatibility/BUILD @@ -24,7 +24,9 @@ py_test( srcs_version = "PY2AND3", deps = [ "tf_upgrade", - "//tensorflow:tensorflow_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "@six_archive//:six", ], ) diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py index f7dbfea7fb0f3463cd708cde8762eb28b69b05a1..d21385e384a45f78fe85ccdd9c60a1836de991be 100644 --- a/tensorflow/tools/dist_test/python/mnist_replica.py +++ b/tensorflow/tools/dist_test/python/mnist_replica.py @@ -123,8 +123,6 @@ def main(unused_argv): is_chief = (FLAGS.task_index == 0) if FLAGS.num_gpus > 0: - if FLAGS.num_gpus < num_workers: - raise ValueError("number of gpus is 
less than number of workers") # Avoid gpu allocation conflict: now allocate task_num -> #gpu # for each worker in the corresponding machine gpu = (FLAGS.task_index % FLAGS.num_gpus) diff --git a/tensorflow/tools/dist_test/scripts/BUILD b/tensorflow/tools/dist_test/scripts/BUILD index c329f0bbe8779fe300e601a1f41d6c123688815a..ce2fa5c743ece40eae10b30f4b2626a9cfada147 100644 --- a/tensorflow/tools/dist_test/scripts/BUILD +++ b/tensorflow/tools/dist_test/scripts/BUILD @@ -17,6 +17,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":k8s_tensorflow_lib", - "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile index 5b3f1f936a48bb448b712152c57c095226efea8e..07a972400df46f59c2d24b7b8e99bd690659b83a 100644 --- a/tensorflow/tools/docker/Dockerfile +++ b/tensorflow/tools/docker/Dockerfile @@ -24,14 +24,15 @@ RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ rm get-pip.py RUN pip --no-cache-dir install \ + Pillow \ + h5py \ ipykernel \ jupyter \ matplotlib \ numpy \ + pandas \ scipy \ sklearn \ - pandas \ - Pillow \ && \ python -m ipykernel.kernelspec diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index d0a038a9db61c97643678d9fbca8974df0f84c8f..e10565a064e2ec232d131366f40347c8a84137ad 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -19,6 +19,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ zlib1g-dev \ openjdk-8-jdk \ openjdk-8-jre-headless \ + wget \ && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu index 3ba1e963f92a0fd7294a36288785545962f40146..da83a300580b660bd2cea890eff8acc8a96103b2 100644 --- a/tensorflow/tools/docker/Dockerfile.gpu +++ b/tensorflow/tools/docker/Dockerfile.gpu @@ -24,14 +24,15 @@ RUN curl -O 
https://bootstrap.pypa.io/get-pip.py && \ rm get-pip.py RUN pip --no-cache-dir install \ + Pillow \ + h5py \ ipykernel \ jupyter \ matplotlib \ numpy \ + pandas \ scipy \ sklearn \ - pandas \ - Pillow \ && \ python -m ipykernel.kernelspec diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md index 6d5a9bdc4ce1a197c3a3fcf5a0c4a48408c2bd92..3780bde2beeac389437627b012d95be7aa9dbbd2 100644 --- a/tensorflow/tools/docker/README.md +++ b/tensorflow/tools/docker/README.md @@ -55,7 +55,7 @@ for additional containers, such as release candidates or nightly builds. ## Rebuilding the containers Building TensorFlow Docker containers should be done through the -[parameterized_docker_build.sh](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/README.md) +[parameterized_docker_build.sh](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/parameterized_docker_build.sh) script. The raw Dockerfiles should not be used directly as they contain strings to be replaced by the script during the build. 
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 8e27b133c2fa33a8f6366b0f94a596cf1ca7c1a2..45722ec9ebda36c380651108cb0727f4c2b958e5 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -37,6 +37,7 @@ py_library( srcs = ["parser.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], + deps = ["@com_github_andreif_codegen"], ) py_test( @@ -44,7 +45,6 @@ py_test( size = "small", srcs = ["parser_test.py"], srcs_version = "PY2AND3", - tags = ["manual"], deps = [ ":parser", "//tensorflow/python:platform_test", @@ -78,13 +78,10 @@ py_test( size = "small", srcs = ["generate_lib_test.py"], srcs_version = "PY2AND3", - tags = ["manual"], deps = [ ":generate_lib", ":parser", - "//tensorflow:tensorflow_py", "//tensorflow/python:platform_test", - "//tensorflow/python/debug:debug_py", ], ) @@ -105,7 +102,6 @@ py_test( srcs = ["build_docs_test.py"], data = ["//tensorflow:docs_src"], srcs_version = "PY2AND3", - tags = ["manual"], deps = [ ":generate_lib", "//tensorflow:tensorflow_py", diff --git a/tensorflow/tools/docs/build_docs_test.py b/tensorflow/tools/docs/build_docs_test.py index d28dd93b9a8d5eb19af414622c1d1b22516f9c1c..ae293f6576456ecdbb8a4b1ee4e8e4f40482ad94 100644 --- a/tensorflow/tools/docs/build_docs_test.py +++ b/tensorflow/tools/docs/build_docs_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import os +import sys +import textwrap import tensorflow as tf from tensorflow.python import debug as tf_debug @@ -29,19 +31,40 @@ from tensorflow.tools.docs import generate_lib class Flags(object): resource_root = resource_loader.get_root_dir_with_all_resources() - src_dir = os.path.join(resource_root, 'third_party/tensorflow/docs_src') - base_dir = os.path.join(resource_root, 'third_party/tensorflow/') + src_dir = os.path.join(resource_root, 'tensorflow/docs_src') + base_dir = os.path.join(resource_root, 'tensorflow/') output_dir = googletest.GetTempDir() class 
BuildDocsTest(googletest.TestCase): def testBuildDocs(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + doc_generator = generate_lib.DocGenerator() doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)]) - status = doc_generator.build(Flags()) + try: + status = doc_generator.build(Flags()) + except RuntimeError as e: + if not e.args[0].startswith('Modules nested too deep'): + raise + + msg = textwrap.dedent("""\ + %s + + **************************************************************** + If this test fails here, you have most likely introduced an + unsealed module. Make sure to use `remove_undocumented` or similar + utilities to avoid leaking symbols. See above for more information + on the exact point of failure. + **************************************************************** + """ % e.args[0]) + + raise RuntimeError(msg) if status: self.fail('Found %s Errors!' % status) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 99872e1d8446ab84bcf77caeb86003d86db85e52..bbeb3921d7b75a9d06d99e0131e1886af3849f2a 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse import os +import sys import six @@ -90,6 +91,7 @@ def write_docs(output_dir, parser_config, yaml_toc): # Parse and write Markdown pages, resolving cross-links (@{symbol}). for full_name, py_object in six.iteritems(parser_config.index): + parser_config.reference_resolver.current_doc_full_name = full_name if full_name in parser_config.duplicate_of: continue @@ -181,7 +183,7 @@ def add_dict_to_dict(add_from, add_to): # Exclude some libaries in contrib from the documentation altogether. def _get_default_private_map(): - return {} + return {'tf.test': ['mock']} # Exclude members of some libaries. 
@@ -390,6 +392,9 @@ def _other_docs(src_dir, output_dir, reference_resolver): print('Skipping excluded file %s...' % base_name) continue full_in_path = os.path.join(dirpath, base_name) + + reference_resolver.current_doc_full_name = full_in_path + suffix = os.path.relpath(path=full_in_path, start=src_dir) full_out_path = os.path.join(output_dir, suffix) if not base_name.endswith('.md'): @@ -415,6 +420,8 @@ class DocGenerator(object): """Main entry point for generating docs.""" def __init__(self): + if sys.version_info >= (3, 0): + sys.exit('Doc generation is not supported from python3.') self.argument_parser = argparse.ArgumentParser() self._py_modules = None self._private_map = _get_default_private_map() @@ -442,7 +449,7 @@ class DocGenerator(object): '--base_dir', type=str, default=default_base_dir, - help='Base directory to to strip from file names referenced in docs.') + help='Base directory to strip from file names referenced in docs.') def parse_known_args(self): flags, _ = self.argument_parser.parse_known_args() @@ -505,7 +512,6 @@ class DocGenerator(object): write_docs(output_dir, parser_config, yaml_toc=self.yaml_toc) _other_docs(flags.src_dir, flags.output_dir, reference_resolver) - if parser.all_errors: - print('Errors during processing:\n ' + '\n '.join(parser.all_errors)) - return 1 - return 0 + parser_config.reference_resolver.log_errors() + + return parser_config.reference_resolver.num_errors() diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py index 6e5deb6a36ed7d7d8b51f28e7ed3d9a680fce13b..1ceaf31f1c3b83e2c2cb3c0d2022ce98781aed4b 100644 --- a/tensorflow/tools/docs/generate_lib_test.py +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -21,9 +21,6 @@ from __future__ import print_function import os import sys -import tensorflow as tf - -from tensorflow.python import debug as tf_debug from tensorflow.python.platform import googletest from tensorflow.tools.docs import generate_lib from 
tensorflow.tools.docs import parser @@ -54,23 +51,10 @@ class DummyVisitor(object): class GenerateTest(googletest.TestCase): - def test_extraction(self): - py_modules = [('tf', tf), ('tfdbg', tf_debug)] - - try: - generate_lib.extract(py_modules, - generate_lib._get_default_private_map(), - generate_lib._get_default_do_not_descend_map()) - except RuntimeError: - print('*****************************************************************') - print('If this test fails, you have most likely introduced an unsealed') - print('module. Make sure to use remove_undocumented or similar utilities') - print('to avoid leaking symbols. See below for more information on the') - print('failure.') - print('*****************************************************************') - raise - def test_write(self): + if sys.version_info >= (3, 0): + self.skipTest('Warning: Doc generation is not supported from python3.') + module = sys.modules[__name__] index = { diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 7ae1d2abd9af813d29e527f447b6ce21c8e72b82..18c3c98dc2b7c43b207c9035d861de3824aade9a 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -24,6 +24,7 @@ import functools import json import os import re +import sys import codegen import six @@ -35,13 +36,36 @@ from tensorflow.python.util import tf_inspect # A regular expression capturing a python indentifier. IDENTIFIER_RE = '[a-zA-Z_][a-zA-Z0-9_]*' -# Log of all reported errors -all_errors = [] +class _Errors(object): + """A collection of errors.""" -def log_error(s): - all_errors.append(s) - print('ERROR:', s) + def __init__(self): + self._errors = [] + + def log_all(self): + """Log all the collected errors to the standard error.""" + template = 'ERROR:\n output file name: %s\n %s\n\n' + + for full_name, message in self._errors: + print(template % (full_name, message), file=sys.stderr) + + def append(self, full_name, message): + """Add an error to the collection. 
+ + Args: + full_name: The path to the file in which the error occurred. + message: The message to display with the error. + """ + self._errors.append((full_name, message)) + + def __len__(self): + return len(self._errors) + + def __eq__(self, other): + if not isinstance(other, _Errors): + return False + return self._errors == other._errors # pylint: disable=protected-access def documentation_path(full_name): @@ -107,6 +131,18 @@ class ReferenceResolver(object): self._all_names = set(is_class.keys()) self._py_module_names = py_module_names + self.current_doc_full_name = None + self._errors = _Errors() + + def add_error(self, message): + self._errors.append(self.current_doc_full_name, message) + + def log_errors(self): + self._errors.log_all() + + def num_errors(self): + return len(self._errors) + @classmethod def from_visitor(cls, visitor, doc_index, **kwargs): """A factory function for building a ReferenceResolver from a visitor. @@ -153,7 +189,8 @@ class ReferenceResolver(object): for key, value in self.__dict__.items(): # Drop these two fields. `_doc_index` is not serializable. `_all_names` is # generated by the constructor. - if key in ('_doc_index', '_all_names'): + if key in ('_doc_index', '_all_names', + '_errors', 'current_doc_full_name'): continue # Strip off any leading underscores on field names as these are not @@ -186,10 +223,10 @@ class ReferenceResolver(object): Returns: `string`, with "@{symbol}" references replaced by Markdown links. 
""" - return re.sub(SYMBOL_REFERENCE_RE, - lambda match: self._one_ref(match.group(1), # pylint: disable=g-long-lambda - relative_path_to_root), - string) + def one_ref(match): + return self._one_ref(match, relative_path_to_root) + + return re.sub(SYMBOL_REFERENCE_RE, one_ref, string) def python_link(self, link_text, ref_full_name, relative_path_to_root, code_ref=True): @@ -250,9 +287,8 @@ class ReferenceResolver(object): # Check whether this link exists if master_name not in self._all_names: - # TODO(josh11b): Make error reporting more uniform. - print('ERROR: Cannot make link to %s (original: %s): Not in index.' % - (master_name, ref_full_name)) + message = 'Cannot make link to "%s": Not in index.' % master_name + self.add_error(message) return 'BROKEN_LINK' # If this is a member of a class, link to the class page with an anchor. @@ -270,8 +306,10 @@ class ReferenceResolver(object): return os.path.join(relative_path_to_root, ref_path) - def _one_ref(self, string, relative_path_to_root): + def _one_ref(self, match, relative_path_to_root): """Return a link for a single "@{symbol}" reference.""" + string = match.group(1) + # Look for link text after $. dollar = string.rfind('$') if dollar > 0: # Ignore $ in first character @@ -303,8 +341,8 @@ class ReferenceResolver(object): code_ref=not manual_link_text) # Error! 
- log_error('Did not understand "@{%s}"' % string) - return 'ERROR:%s' % string + self.add_error('Did not understand "%s"' % match.group(0)) + return 'BROKEN_LINK' def _doc_link(self, string, link_text, manual_link_text, relative_path_to_root): @@ -330,7 +368,7 @@ class ReferenceResolver(object): def _doc_missing(self, string, unused_hash_tag, link_text, unused_manual_link_text, unused_relative_path_to_root): """Generate an error for unrecognized @{$...} references.""" - log_error('Handle doc reference "@{$%s}"' % string) + self.add_error('Unknown Document "%s"' % string) return link_text def _cc_link(self, string, link_text, unused_manual_link_text, @@ -348,7 +386,7 @@ class ReferenceResolver(object): elif string == 'tensorflow::ops::Const': ret = 'namespace/tensorflow/ops.md#const' else: - log_error('Handle C++ reference "@{%s}"' % string) + self.add_error('C++ reference not understood: "%s"' % string) return 'TODO_C++:%s' % string # relative_path_to_root gets you to api_docs/python, we go from there # to api_docs/cc, and then add ret. 
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 3e02160130f1959484472ecc77e8b2e883294a1e..862f0acfa90fbc8ea7f5054b745c684783f1ff5a 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -491,13 +491,13 @@ Returns: class TestParseFunctionDetails(googletest.TestCase): - def testParseFunctionDetails(self): + def test_parse_function_details(self): docstring, function_details = parser._parse_function_details(RELU_DOC) self.assertEqual(len(function_details), 2) args = function_details[0] self.assertEqual(args.keyword, 'Args') - self.assertEmpty(args.header) + self.assertEqual(len(args.header), 0) self.assertEqual(len(args.items), 2) self.assertEqual(args.items[0][0], 'features') self.assertEqual(args.items[1][0], 'name') @@ -515,5 +515,60 @@ class TestParseFunctionDetails(googletest.TestCase): docstring + ''.join(str(detail) for detail in function_details)) +class TestGenerateSignature(googletest.TestCase): + + def test_known_object(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + + known_object = object() + reverse_index = {id(known_object): 'location.of.object.in.api'} + + def example_fun(arg=known_object): # pylint: disable=unused-argument + pass + + sig = parser._generate_signature(example_fun, reverse_index) + self.assertEqual(sig, ['arg=location.of.object.in.api']) + + def test_literals(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + + def example_fun(a=5, b=5.0, c=None, d=True, e='hello', f=(1, (2, 3))): # pylint: disable=g-bad-name, unused-argument + pass + + sig = parser._generate_signature(example_fun, reverse_index={}) + self.assertEqual( + sig, ['a=5', 'b=5.0', 'c=None', 'd=True', "e='hello'", 'f=(1, (2, 3))']) + + def test_dotted_name(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from 
python3.') + return + + # pylint: disable=g-bad-name + class a(object): + + class b(object): + + class c(object): + + class d(object): + + def __init__(self, *args): + pass + # pylint: enable=g-bad-name + + e = {'f': 1} + + def example_fun(arg1=a.b.c.d, arg2=a.b.c.d(1, 2), arg3=e['f']): # pylint: disable=unused-argument + pass + + sig = parser._generate_signature(example_fun, reverse_index={}) + self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"]) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index fa2cf15cb16ec3396089b5f52ce8718fd05f94a0..d9ec8e8e9b45c0f503908a6608494e6356c012a7 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -240,7 +240,6 @@ cc_binary( deps = [ ":transform_utils", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -250,7 +249,12 @@ py_library( name = "transform_graph_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:pywrap_tensorflow"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:errors", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:util", + ], ) tf_py_test( diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc index 066727614c8a24329d9d2f45d9dfe946a51b322b..0978c336b49ce8cc72d9fc35af551a7f15ee697f 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc @@ -54,7 +54,7 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, GraphDef replaced_graph_def; TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( current_graph_def, // clang-format off - {"BatchNormWithGlobalNormalization", // batch_norm_node + {"BatchNormWithGlobalNormalization|FusedBatchNorm", 
// batch_norm_node { {"Conv2D", // conv_node { @@ -74,19 +74,33 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, std::vector* new_nodes) { // Find all the nodes we expect in the subgraph. const NodeDef& batch_norm_node = match.node; - CHECK_EQ("BatchNormWithGlobalNormalization", batch_norm_node.op()); + // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ + // by input order and attribute names. + CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" || + batch_norm_node.op() == "FusedBatchNorm"); + const bool is_fused = batch_norm_node.op() == "FusedBatchNorm"; + const int mean_idx = is_fused ? 3 : 1; + const int var_idx = is_fused ? 4 : 2; + const int beta_idx = is_fused ? 2 : 3; + const int gamma_idx = is_fused ? 1 : 4; + const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon"; + // FusedBatchNorm always scales after normalization. + const bool scale_after_normalization = + is_fused || + batch_norm_node.attr().at("scale_after_normalization").b(); + const NodeDef& conv_node = match.inputs[0].node; CHECK_EQ("Conv2D", conv_node.op()); const NodeDef& input_node = match.inputs[0].inputs[0].node; const NodeDef& weights_node = match.inputs[0].inputs[1].node; CHECK_EQ("Const", weights_node.op()); - const NodeDef& mean_node = match.inputs[1].node; + const NodeDef& mean_node = match.inputs[mean_idx].node; CHECK_EQ("Const", mean_node.op()); - const NodeDef& variance_node = match.inputs[2].node; + const NodeDef& variance_node = match.inputs[var_idx].node; CHECK_EQ("Const", variance_node.op()); - const NodeDef& beta_node = match.inputs[3].node; + const NodeDef& beta_node = match.inputs[beta_idx].node; CHECK_EQ("Const", beta_node.op()); - const NodeDef& gamma_node = match.inputs[4].node; + const NodeDef& gamma_node = match.inputs[gamma_idx].node; CHECK_EQ("Const", gamma_node.op()); // We have a set of vectors that we want to combine into a vector of @@ -98,9 +112,7 @@ Status FoldOldBatchNorms(const GraphDef& 
input_graph_def, Tensor beta = GetNodeTensorAttr(beta_node, "value"); Tensor gamma = GetNodeTensorAttr(gamma_node, "value"); const float variance_epsilon = - batch_norm_node.attr().at("variance_epsilon").f(); - const bool scale_after_normalization = - batch_norm_node.attr().at("scale_after_normalization").b(); + batch_norm_node.attr().at(epsilon_attr).f(); // Make sure all the inputs really are vectors, with as many entries // as there are columns in the weights. @@ -119,16 +131,17 @@ Status FoldOldBatchNorms(const GraphDef& input_graph_def, scale_values[i] = (1.0f / sqrtf(variance.flat()(i) + variance_epsilon)) * gamma.flat()(i); - offset_values[i] = 0.0f; } } else { for (int i = 0; i < weights_cols; ++i) { scale_values[i] = (1.0f / sqrtf(variance.flat()(i) + variance_epsilon)); - offset_values[i] = (-mean.flat()(i) * scale_values[i]) + - beta.flat()(i); } } + for (int i = 0; i < weights_cols; ++i) { + offset_values[i] = (-mean.flat()(i) * scale_values[i]) + + beta.flat()(i); + } // Multiply the original weights by the scale vector. auto weights_matrix = weights.flat_inner_dims(); diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc index 1c4958d83c935e4b298b54461e820b15608d7b8e..3be9110b475f97087be18118d2ba0c52d6388c03 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -118,11 +119,92 @@ class FoldOldBatchNormsTest : public ::testing::Test { EXPECT_NE("BatchNormWithGlobalNormalization", node.op()); } } + + void TestFoldFusedBatchNorms() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + Tensor mean_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&mean_data, {10.0f, 20.0f}); + Output mean_op = + Const(root.WithOpName("mean_op"), Input::Initializer(mean_data)); + + Tensor variance_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&variance_data, {0.25f, 0.5f}); + Output variance_op = Const(root.WithOpName("variance_op"), + Input::Initializer(variance_data)); + + Tensor beta_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&beta_data, {0.1f, 0.6f}); + Output beta_op = + Const(root.WithOpName("beta_op"), Input::Initializer(beta_data)); + + Tensor gamma_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&gamma_data, {1.0f, 2.0f}); + Output gamma_op = + Const(root.WithOpName("gamma_op"), 
Input::Initializer(gamma_data)); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + NodeDef batch_norm_node; + batch_norm_node.set_op("FusedBatchNorm"); + batch_norm_node.set_name("output"); + AddNodeInput("conv_op", &batch_norm_node); + AddNodeInput("gamma_op", &batch_norm_node); + AddNodeInput("beta_op", &batch_norm_node); + AddNodeInput("mean_op", &batch_norm_node); + AddNodeInput("variance_op", &batch_norm_node); + SetNodeAttr("T", DT_FLOAT, &batch_norm_node); + SetNodeAttr("epsilon", 0.00001f, &batch_norm_node); + SetNodeAttr("is_training", false, &batch_norm_node); + *(original_graph_def.mutable_node()->Add()) = batch_norm_node; + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}}, + &fused_graph_def)); + + std::unique_ptr fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("FusedBatchNorm", node.op()); + } + } }; TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) { TestFoldOldBatchNorms(); } +TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNorms) { + TestFoldFusedBatchNorms(); +} + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc index da064377ac3f2919e0d0421099d9407a35518e22..2b85e7e83c6f3e2c8d0840f0b9eb0b4992a8b113 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes.cc +++ 
b/tensorflow/tools/graph_transforms/quantize_nodes.cc @@ -119,6 +119,13 @@ const std::vector& GetQuantizedOpList() { DT_QUINT8, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"ResizeBilinear", + {"align_corners"}, + {{"T", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {1}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, {"Relu6", {}, {{"Tinput", DT_QUINT8}}, diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc index d02655f3f9cb5093a9c542e90aef2f8069e6e1dd..eca263a1ae0dbfad51565b1d3d0d26b066704fc8 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc @@ -106,8 +106,8 @@ class QuantizeNodesTest : public ::testing::Test { // Reshape is not included here because it can be added as part of the // quantization process. const std::set quantizable_ops = { - "Add", "BiasAdd", "Concat", "Conv2D", "MatMul", - "Relu", "Relu6", "AvgPool", "MaxPool", "Mul"}; + "Add", "BiasAdd", "Concat", "Conv2D", "MatMul", "Relu", + "Relu6", "ResizeBilinear", "AvgPool", "MaxPool", "Mul"}; for (const NodeDef& node : quantized_graph_def.node()) { EXPECT_EQ(0, quantizable_ops.count(node.op())) << "Found quantizable node " << node.op() << " for node named " @@ -652,6 +652,33 @@ class QuantizeNodesTest : public ::testing::Test { EXPECT_EQ("requantize_op", node_map.at("final_dequantize")->input(0)); } + void TestQuantizeResizeBilinear() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor size_tensor(DT_INT32, TensorShape({2})); + test::FillValues(&size_tensor, {256, 256}); + + Output constant_op = Const(root.WithOpName("size_tensor_op"), + Input::Initializer(size_tensor)); + + Output placeholder_op = + Placeholder(root.WithOpName("placeholder_op"), DT_FLOAT); + + Output resize_bilinear_op = ResizeBilinear( + root.WithOpName("resize_bilinear_op"), placeholder_op, constant_op); + + GraphDef 
float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + Tensor input_tensor(DT_FLOAT, {1, 128, 128, 3}); + test::FillFn(&input_tensor, [](int) { return 100.0f; }); + + TestQuantizedVersusFloatGraph(float_graph_def, + {{"placeholder_op", input_tensor}}, + {"resize_bilinear_op"}); + } + void TestRemoveRedundantQuantizationWithMultipleOutputs() { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) @@ -1446,6 +1473,10 @@ TEST_F(QuantizeNodesTest, TestQuantizeAvgPool) { TestQuantizeAvgPool(); } TEST_F(QuantizeNodesTest, TestQuantizeReshape) { TestQuantizeReshape(); } +TEST_F(QuantizeNodesTest, TestQuantizeResizeBilinear) { + TestQuantizeResizeBilinear(); +} + TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantization) { TestRemoveRedundantQuantization(); } diff --git a/tensorflow/tools/graph_transforms/rename_attribute_test.cc b/tensorflow/tools/graph_transforms/rename_attribute_test.cc index a0a33e9fc090acea176333ec840e2e6f438ca998..31619d82ad998a48dde7a3c73fba12a16a0360c2 100644 --- a/tensorflow/tools/graph_transforms/rename_attribute_test.cc +++ b/tensorflow/tools/graph_transforms/rename_attribute_test.cc @@ -43,17 +43,17 @@ class RenameAttributeTest : public ::testing::Test { mul_node1->set_op("Mul"); mul_node1->add_input("add_node2"); mul_node1->add_input("add_node3"); - AddNodeAttr("foo", 23, mul_node1); - AddNodeAttr("bar", "something", mul_node1); + AddNodeAttr("foo", 23, mul_node1); + AddNodeAttr("bar", "something", mul_node1); NodeDef* add_node2 = graph_def.add_node(); add_node2->set_name("add_node2"); add_node2->set_op("Add"); add_node2->add_input("const_node1"); add_node2->add_input("const_node2"); - AddNodeAttr("foo", 46, add_node2); - AddNodeAttr("bob", 23, add_node2); - AddNodeAttr("bar", "something else", add_node2); + AddNodeAttr("foo", 46, add_node2); + AddNodeAttr("bob", 23, add_node2); + AddNodeAttr("bar", "something else", add_node2); NodeDef* add_node3 = graph_def.add_node(); 
add_node3->set_name("add_node3"); diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc index 4eb074998f71e8c1ff51ea64463ff35660bcedca..c0107014e2cf115aeafe78ca879c0cb169cb335b 100644 --- a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index 91670f54d49d057bbb5ff894247c79538877ef5f..e79e7ba121c93da7704a89fe8900d9615ab2c3b5 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -23,8 +23,10 @@ limitations under the License. 
// bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ // --in_graph=my_graph.pb +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index 6ed549a9589af2ff287aa199b2cfb113e40bf871..2db0a24267bba49929f8e02240e7b471028ba9d8 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -20,10 +20,12 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 51ba3b7a0be143a0186269678d508f4f0e95c55c..9da5d5cb5b818ef14fb4edb7bb294d9c5a05c561 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -87,6 +87,7 @@ genrule( "//third_party/fft2d:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", @@ -117,6 +118,7 @@ genrule( "//third_party/fft2d:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", diff --git a/tensorflow/tools/mlpbtxt/BUILD 
b/tensorflow/tools/mlpbtxt/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..fc63e9a0b73fd92c63cde5d60bdb9b984922f820 --- /dev/null +++ b/tensorflow/tools/mlpbtxt/BUILD @@ -0,0 +1,44 @@ +# Description: +# This package provides binaries that convert between multi-line and standard +# pbtxt (text-serialization of protocol message) files. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files([ + "LICENSE", + "placeholder.txt", +]) + +cc_binary( + name = "tomlpbtxt", + srcs = ["tomlpbtxt.cc"], + deps = [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", + ], +) + +cc_binary( + name = "frommlpbtxt", + srcs = ["frommlpbtxt.cc"], + deps = [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tools/mlpbtxt/frommlpbtxt.cc b/tensorflow/tools/mlpbtxt/frommlpbtxt.cc new file mode 100644 index 0000000000000000000000000000000000000000..643924b318d3fec850ebd6c8275a2eab4884a644 --- /dev/null +++ b/tensorflow/tools/mlpbtxt/frommlpbtxt.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +int Run(int argc, char** argv) { + string FLAGS_in = ""; + string FLAGS_out = ""; + + std::vector flag_list = { + Flag("in", &FLAGS_in, "Input multi-line proto text (.mlpbtxt) file name"), + Flag("out", &FLAGS_out, "Output proto text (.pbtxt) file name")}; + + // Parse the command-line. + const string usage = Flags::Usage(argv[0], flag_list); + const bool parse_ok = Flags::Parse(&argc, argv, flag_list); + if (argc != 1 || !parse_ok) { + printf("%s", usage.c_str()); + return 2; + } + + port::InitMain(argv[0], &argc, &argv); + + // Read the input file --in. + string in_contents; + Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents); + if (!s.ok()) { + printf("Error reading file %s: %s\n", FLAGS_in.c_str(), + s.ToString().c_str()); + return 1; + } + + // Write the output file --out. + const string out_contents = PBTxtFromMultiline(in_contents); + s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents); + if (!s.ok()) { + printf("Error writing file %s: %s\n", FLAGS_out.c_str(), + s.ToString().c_str()); + return 1; + } + + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { return tensorflow::Run(argc, argv); } diff --git a/tensorflow/tools/mlpbtxt/tomlpbtxt.cc b/tensorflow/tools/mlpbtxt/tomlpbtxt.cc new file mode 100644 index 0000000000000000000000000000000000000000..469be49ed3c966c671f1f45619d0a8d88fe519f1 --- /dev/null +++ b/tensorflow/tools/mlpbtxt/tomlpbtxt.cc @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +int Run(int argc, char** argv) { + string FLAGS_in = ""; + string FLAGS_out = ""; + string FLAGS_fields = "description"; + + std::vector flag_list = { + Flag("in", &FLAGS_in, "Input proto text (.pbtxt) file name"), + Flag("out", &FLAGS_out, + "Output multi-line proto text (.mlpbtxt) file name"), + Flag("fields", &FLAGS_fields, "Comma-separated list of field names")}; + + // Parse the command-line. + const string usage = Flags::Usage(argv[0], flag_list); + const bool parse_ok = Flags::Parse(&argc, argv, flag_list); + if (argc != 1 || !parse_ok) { + printf("%s", usage.c_str()); + return 2; + } + + // Parse the --fields option. + std::vector fields = + str_util::Split(FLAGS_fields, ',', str_util::SkipEmpty()); + if (fields.empty()) { + printf("--fields must be non-empty.\n%s", usage.c_str()); + return 2; + } + + port::InitMain(argv[0], &argc, &argv); + + // Read the input file --in. 
+ string in_contents; + Status s = ReadFileToString(Env::Default(), FLAGS_in, &in_contents); + if (!s.ok()) { + printf("Error reading file %s: %s\n", FLAGS_in.c_str(), + s.ToString().c_str()); + return 1; + } + + // Write the output file --out. + const string out_contents = PBTxtToMultiline(in_contents, fields); + s = WriteStringToFile(Env::Default(), FLAGS_out, out_contents); + if (!s.ok()) { + printf("Error writing file %s: %s\n", FLAGS_out.c_str(), + s.ToString().c_str()); + return 1; + } + + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { return tensorflow::Run(argc, argv); } diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 798338d787551769d94afd9f774a23655a640086..78a652ccaeade2d3ae7928d0754609e864d8206f 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -75,10 +75,6 @@ py_binary( "//tensorflow/python/saved_model", "//tensorflow/python/tools:tools_pip", # These targets don't build on Windows yet. Exclude them for now. - # rules_closure currently doesn't build on Windows due to - # https://github.com/bazelbuild/rules_closure/pull/206 - # Since tensorboard dependes on rules_closure, exclude tensorboard until it's fixed. 
- # "//tensorflow/tensorboard", # "//tensorflow/contrib/ndlstm", # "//tensorflow/contrib/slim", # "//tensorflow/contrib/slim/python/slim/nets:nets_pip", @@ -99,6 +95,7 @@ filegroup( "//third_party/hadoop:LICENSE.txt", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", + "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", @@ -113,15 +110,12 @@ filegroup( "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@nanopb_git//:LICENSE.txt", - "@org_html5lib//:LICENSE", - "@org_mozilla_bleach//:LICENSE", - "@org_pocoo_werkzeug//:LICENSE", - "@org_pythonhosted_markdown//:LICENSE.md", "@png_archive//:LICENSE", "@protobuf//:LICENSE", "@six_archive//:LICENSE", "@snappy//:COPYING", "@zlib_archive//:zlib.h", + "@org_python_pypi_backports_weakref//:LICENSE", ] + if_not_windows([ "@nccl_archive//:LICENSE.txt", ]) + tf_additional_license_deps(), @@ -141,11 +135,13 @@ sh_binary( ":included_headers", ":simple_console", "//tensorflow:tensorflow_py", + "//tensorflow/contrib/boosted_trees:boosted_trees_pip", "//tensorflow/contrib/graph_editor:graph_editor_pip", "//tensorflow/contrib/keras:keras", "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", "//tensorflow/contrib/ndlstm:ndlstm", "//tensorflow/contrib/nn:nn_py", + "//tensorflow/contrib/predictor:predictor_pip", "//tensorflow/contrib/session_bundle:session_bundle_pip", "//tensorflow/contrib/signal:signal_py", "//tensorflow/contrib/slim:slim", @@ -154,6 +150,10 @@ sh_binary( "//tensorflow/contrib/specs:specs", "//tensorflow/contrib/tensor_forest:init_py", "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip", + "//tensorflow/contrib/timeseries:timeseries_pip", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_helper_library", + "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/examples/tutorials/mnist:package", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:meta_graph_testdata", @@ -161,7 +161,6 
@@ sh_binary( "//tensorflow/python/debug:debug_pip", "//tensorflow/python/saved_model:saved_model", "//tensorflow/python/tools:tools_pip", - "//tensorflow/tensorboard", ], }) + if_mkl(["//third_party/mkl:intel_binary_blob"]), ) diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index dec08157c2c3edf9e632227fb54a50abf3b1b49d..83909d83ae4c45404419745ef7982649e7f416f5 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -60,6 +60,12 @@ BLACKLIST = [ "//tensorflow/contrib/framework:checkpoint_ops_testdata", "//tensorflow/contrib/bayesflow:reinforce_simple_example", "//tensorflow/contrib/bayesflow:examples/reinforce_simple/reinforce_simple_example.py", # pylint:disable=line-too-long + "//tensorflow/contrib/timeseries/examples:predict", + "//tensorflow/contrib/timeseries/examples:multivariate", + "//tensorflow/contrib/timeseries/examples:known_anomaly", + "//tensorflow/contrib/timeseries/examples:data/period_trend.csv", # pylint:disable=line-too-long + "//tensorflow/contrib/timeseries/python/timeseries:test_utils", + "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:test_utils", # pylint:disable=line-too-long ] diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index dd47b44001a80973715863b89058e250f5a07146..39499e17758dee4b4458cd14d98ddf00af504dc7 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -29,17 +29,14 @@ from setuptools.dist import Distribution # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. 
-_VERSION = '1.2.0' +_VERSION = '1.2.1' REQUIRED_PACKAGES = [ 'numpy >= 1.11.0', 'six >= 1.10.0', - 'protobuf >= 3.3.0', - 'werkzeug >= 0.11.10', - 'html5lib == 0.9999999', # identical to 1.0b8 - 'markdown == 2.2.0', - 'bleach == 1.5.0', + 'protobuf >= 3.2.0', 'backports.weakref == 1.0rc1', + 'tensorflow-tensorboard', ] project_name = 'tensorflow' @@ -59,7 +56,6 @@ else: # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ - 'tensorboard = tensorflow.tensorboard.tensorboard:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', ] # pylint: enable=line-too-long @@ -191,8 +187,6 @@ setup( package_data={ 'tensorflow': [ EXTENSION_NAME, - 'tensorboard/components/index.html', - 'tensorboard/TAG', ] + matches, }, zip_safe=False, diff --git a/tensorflow/tools/quantization/BUILD b/tensorflow/tools/quantization/BUILD index cb41185219c56f9a0d834a2e4b5b71c57b46810a..e99ad06a06294c4d037b76ea9450e51bd795e79d 100644 --- a/tensorflow/tools/quantization/BUILD +++ b/tensorflow/tools/quantization/BUILD @@ -13,7 +13,20 @@ py_library( name = "quantize_graph_lib", srcs = ["quantize_graph.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow:tensorflow_py"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:graph_util", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//third_party/py/numpy", + ], ) py_binary( @@ -27,18 +40,17 @@ py_binary( "//tensorflow/python:client", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:graph_util", "//tensorflow/python:platform", + "//tensorflow/python:tensor_util", "//third_party/py/numpy", - "@six_archive//:six", ], ) py_test( name = "quantize_graph_test", size = "small", - srcs = [ 
- "quantize_graph_test.py", - ], + srcs = ["quantize_graph_test.py"], srcs_version = "PY2AND3", tags = ["nomsan"], # http://b/32242946 deps = [ @@ -48,6 +60,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:graph_util", "//tensorflow/python:platform", "//third_party/py/numpy", ], @@ -55,12 +68,13 @@ py_test( py_binary( name = "graph_to_dot", - srcs = [ - "graph_to_dot.py", - ], + srcs = ["graph_to_dot.py"], main = "graph_to_dot.py", srcs_version = "PY2AND3", - deps = ["//tensorflow/python:platform"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + ], ) filegroup( diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD index 9367bcd4a3457d7387ee8dc17a4d19043fa8c9a2..28d651e9106b29058824c06b160df2b9b5781757 100644 --- a/tensorflow/tools/test/BUILD +++ b/tensorflow/tools/test/BUILD @@ -22,6 +22,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/core:protos_all_py", "//tensorflow/python:client", "//tensorflow/python:errors", "//tensorflow/python:platform", @@ -46,6 +47,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":system_info_lib", + "//tensorflow/core:protos_all_py", "//tensorflow/python:platform", ], ) @@ -54,8 +56,10 @@ py_binary( name = "run_and_gather_logs", srcs = ["run_and_gather_logs.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":run_and_gather_logs_lib", + "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", ], diff --git a/tensorflow/tools/test/performance.bzl b/tensorflow/tools/test/performance.bzl index 2956c6dde74ff38a7f000d6b6b595beaa397fa76..64fff844a70d439306c9bcf7f21d5a6047fa428a 100644 --- a/tensorflow/tools/test/performance.bzl +++ b/tensorflow/tools/test/performance.bzl @@ -28,7 +28,7 @@ def tf_cc_logged_benchmark( name = name, tags = all_tags, size = "large", - srcs = 
["//tensorflow/tools/test:run_and_gather_logs.py"], + srcs = ["//tensorflow/tools/test:run_and_gather_logs"], args = [ "--name=//%s:%s" % (PACKAGE_NAME, name), "--test_name=" + target, diff --git a/tensorflow/tools/test/upload_test_benchmarks.py b/tensorflow/tools/test/upload_test_benchmarks.py index 829333e05629946fc5627d37301883d70572b1be..77cc9f75f7725438918f681833d58e9ecb4a2f70 100644 --- a/tensorflow/tools/test/upload_test_benchmarks.py +++ b/tensorflow/tools/test/upload_test_benchmarks.py @@ -162,7 +162,7 @@ def upload_benchmark_data(client, data): t_val.update({ "test": test_name, "start": start_time, - "info": unicode(test_result) + "info": unicode(data) }) batch.append(t_val) diff --git a/tensorflow/tools/test/upload_test_benchmarks_index.yaml b/tensorflow/tools/test/upload_test_benchmarks_index.yaml index 8cd33a1da60cad1c1a0e21998b4eefc81babfd8e..ec7f76f6663b3e586b4b63e92eb576740cd445f9 100644 --- a/tensorflow/tools/test/upload_test_benchmarks_index.yaml +++ b/tensorflow/tools/test/upload_test_benchmarks_index.yaml @@ -27,7 +27,7 @@ indexes: properties: - name: test - name: start - direction: asc + direction: desc # Index to access a specific (test, entry, start) Entity, and also to be able to # fetch a range of (start, timing) graph values for a given (test, entry) pair diff --git a/tensorflow/tools/tfprof/g3doc/advise.md b/tensorflow/tools/tfprof/g3doc/advise.md deleted file mode 100644 index 3bce6270ff8368fb57d183c6f4c6a88f5dd5bc07..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/g3doc/advise.md +++ /dev/null @@ -1,44 +0,0 @@ -## Auto Detect and Advise - -tfprof analyzes profiles and generates advises for common issues. - -### Run Advise. -```python -# First create a profiler. See profiler tutorials for more details. 
-profiler = model_analyzer.Profiler(sess.graph) -run_meta = config_pb2.RunMetadata() -_ = sess.run(r1, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE), - run_metadata=run_meta) -profiler.add_step(1, run_meta) - -# Start advise. -profiler.advise() -``` - -### Checker - -There is no magic behind advise mode. tfprof builds the profiles first, then -it runs through a list of `Checkers`, each one responsible for checking one -area with the profile and report issues. A `Checker` is like a plugin. - -For example: - -####JobChecker (Not Available OSS) -* Checking RecvTensor RPC latency and bandwidth. -* Checking CPU/Memory utilization of the job. - -####AcceleratorUtilization Checker -* Checks what percentage of time the accelerator spends on computation. - -####Operation Checker -* Check whether the operation runs with optimal options. -* Checks if there is a better implementation to replace the current operation. - -####Contribute Your Checker - -Follow examples of accelerator_utilization_checker.h - - - diff --git a/tensorflow/tools/tfprof/internal/advisor/checker.h b/tensorflow/tools/tfprof/internal/advisor/checker.h deleted file mode 100644 index b8b057be5b1d6410acfb3e8607693e303a6f963c..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/internal/advisor/checker.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_CHECKER_H_ - -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/tools/tfprof/internal/tfprof_stats.h" - -namespace tensorflow { -namespace tfprof { - -static const char* const kLevel[] = { - "NOTE", // Good to know. - "SUGGEST", // Might get better. - "WARN", // Please do it for better. -}; - -class Checker { - public: - virtual ~Checker(){}; - - virtual string name() = 0; - - std::vector Run(const TFStats* stats) { return Check(stats); } - - protected: - // Returns a vector of string, each one being an advice. - virtual std::vector Check(const TFStats* stats) = 0; -}; -} // namespace tfprof -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_CHECKER_H_ diff --git a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h b/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h deleted file mode 100644 index 856f51545921283799a87e053c96a19d0ee4387d..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/internal/advisor/tfprof_advisor.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ -#define THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ - -#include "tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h" -#include "tensorflow/tools/tfprof/internal/advisor/checker.h" -#include "tensorflow/tools/tfprof/internal/advisor/internal_checker_runner.h" -#include "tensorflow/tools/tfprof/internal/advisor/operation_checker.h" - -namespace tensorflow { -namespace tfprof { - -// The Advisor runs a list of Checkers, each checks a specific area. -class Advisor { - public: - Advisor(const TFStats* stats) : stats_(stats) {} - - std::map> Advise() { - // Note: Release a checker's memory ASAP. - std::map> reports = RunInternalCheckers(stats_); - // TODO(xpan): Think of a way to turn off/on specific checkers. - AcceleratorUtilizationChecker au_checker; - reports[au_checker.name()] = au_checker.Run(stats_); - OperationChecker op_checker; - reports[op_checker.name()] = op_checker.Run(stats_); - - for (const auto& checker_r : reports) { - fprintf(stdout, "%s reports:\n", checker_r.first.c_str()); - for (const auto& r : checker_r.second) { - fprintf(stdout, "%s\n", r.c_str()); - } - } - fflush(stdout); - return reports; - } - - private: - const TFStats* stats_; -}; - -} // namespace tfprof -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_TOOLS_TFPROF_INTERNAL_ADVISOR_TFPROF_ADVICE_H_ diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc b/tensorflow/tools/tfprof/internal/tfprof_show_test.cc deleted file mode 100644 index 498477de0a00f828b07a6a955e05722d6a79d433..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/tools/tfprof/internal/tfprof_stats.h" - -#include - -#include "tensorflow/c/checkpoint_reader.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/tools/tfprof/internal/tfprof_constants.h" -#include "tensorflow/tools/tfprof/internal/tfprof_options.h" -#include "tensorflow/tools/tfprof/internal/tfprof_utils.h" -#include "tensorflow/tools/tfprof/tfprof_log.pb.h" -#include "tensorflow/tools/tfprof/tfprof_output.pb.h" - -namespace tensorflow { -namespace tfprof { -class TFProfShowTest : public ::testing::Test { - protected: - TFProfShowTest() { - string graph_path = - io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/graph.pbtxt"); - std::unique_ptr graph_pb(new tensorflow::GraphDef()); - TF_CHECK_OK(ReadGraphDef(Env::Default(), graph_path, graph_pb.get())); - - std::unique_ptr run_meta_pb( - new tensorflow::RunMetadata()); - string run_meta_path = - io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/run_meta"); - TF_CHECK_OK( - ReadBinaryProto(Env::Default(), run_meta_path, run_meta_pb.get())); - - std::unique_ptr op_log_pb(new OpLog()); - string op_log_path = - 
io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/tfprof_log"); - TF_CHECK_OK(ReadBinaryProto(Env::Default(), op_log_path, op_log_pb.get())); - - string ckpt_path = io::JoinPath(testing::TensorFlowSrcRoot(), - "tools/tfprof/internal/testdata/ckpt"); - TF_Status* status = TF_NewStatus(); - std::unique_ptr ckpt_reader( - new checkpoint::CheckpointReader(ckpt_path, status)); - CHECK(TF_GetCode(status) == TF_OK); - TF_DeleteStatus(status); - - tf_stats_.reset(new TFStats(std::move(graph_pb), std::move(run_meta_pb), - std::move(op_log_pb), std::move(ckpt_reader))); - } - - std::unique_ptr tf_stats_; -}; - -TEST_F(TFProfShowTest, DumpScopeMode) { - string dump_file = io::JoinPath(testing::TmpDir(), "dump"); - Options opts(5, 0, 0, 0, 0, 0, -1, "name", - {"VariableV2"}, // accout_type_regexes - {".*"}, {""}, {".*"}, {""}, false, - {"params", "bytes", "micros", "float_ops"}, "file", - {{"outfile", dump_file}}); - tf_stats_->ShowGraphNode("scope", opts); - - string dump_str; - TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); - EXPECT_EQ( - "node name | # parameters | # float_ops | output bytes | execution " - "time\n_TFProfRoot (--/370 params, --/0 flops, --/1.48KB, --/5us)\n " - "conv2d (--/140 params, --/0 flops, --/560B, --/2us)\n conv2d/bias " - "(5, 5/5 params, 0/0 flops, 20B/20B, 1us/1us)\n conv2d/kernel " - "(3x3x3x5, 135/135 params, 0/0 flops, 540B/540B, 1us/1us)\n conv2d_1 " - "(--/230 params, --/0 flops, --/920B, --/3us)\n conv2d_1/bias (5, 5/5 " - "params, 0/0 flops, 20B/20B, 1us/1us)\n conv2d_1/kernel (3x3x5x5, " - "225/225 params, 0/0 flops, 900B/900B, 2us/2us)\n", - dump_str); -} - -TEST_F(TFProfShowTest, DumpOpMode) { - string dump_file = io::JoinPath(testing::TmpDir(), "dump"); - Options opts( - 5, 0, 0, 0, 0, 4, -1, "params", {".*"}, // accout_type_regexes - {".*"}, {""}, {".*"}, {""}, false, - {"params", "bytes", "micros", "float_ops", "occurrence", "input_shapes"}, - "file", {{"outfile", dump_file}}); - 
tf_stats_->ShowMultiGraphNode("op", opts); - - string dump_str; - TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); - EXPECT_EQ( - "nodename|outputbytes|executiontime|#parameters|#float_ops|opoccurrence|" - "inputshapes\nVariableV21.48KB(100.00%,17.10%),5us(100.00%,5.15%)," - "370params(100.00%,100.00%),0float_ops(100.00%,0.00%),4\n\ninput_type:\t(" - "*4)\texec_time:5us\n\nAssign0B(0.00%,0.00%),0us(94.85%,0.00%),0params(0." - "00%,0.00%),0float_ops(100.00%,0.00%),8\n\ninput_type:0:unknown,\t1:" - "unknown\t(*8)\texec_time:0us\n\nConst1.54KB(58.87%,17.74%),1us(80.41%,1." - "03%),0params(0.00%,0.00%),0float_ops(98.49%,0.00%),24\n\ninput_type:\t(*" - "24)\texec_time:1us\n\n", - StringReplace(dump_str, " ", "")); -} -} // namespace tfprof -} // namespace tensorflow diff --git a/tensorflow/tools/tfprof/tfprof_main.cc b/tensorflow/tools/tfprof/tfprof_main.cc deleted file mode 100644 index ae02b526347474e1aa738ee1a84cfabaeb7d723c..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/tfprof_main.cc +++ /dev/null @@ -1,286 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "linenoise.h" -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/checkpoint_reader.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/tools/tfprof/internal/tfprof_options.h" -#include "tensorflow/tools/tfprof/internal/tfprof_stats.h" -#include "tensorflow/tools/tfprof/internal/tfprof_utils.h" -#include "tensorflow/tools/tfprof/tfprof_log.pb.h" - -using tensorflow::str_util::Split; - -void completion(const char* buf, linenoiseCompletions* lc) { - tensorflow::string buf_str = buf; - if (buf_str.find(" ") == buf_str.npos) { - for (const char* opt : tensorflow::tfprof::kCmds) { - if (tensorflow::string(opt).find(buf_str) == 0) { - linenoiseAddCompletion(lc, opt); - } - } - return; - } - - tensorflow::string prefix; - int last_dash = buf_str.find_last_of(' '); - if (last_dash != tensorflow::string::npos) { - prefix = buf_str.substr(0, last_dash + 1); - buf_str = buf_str.substr(last_dash + 1, tensorflow::kint32max); - } - for (const char* opt : tensorflow::tfprof::kOptions) { - if (tensorflow::string(opt).find(buf_str) == 0) { - linenoiseAddCompletion(lc, (prefix + opt).c_str()); - } - } -} - -int main(int argc, char** argv) { - tensorflow::string FLAGS_graph_path = ""; - tensorflow::string FLAGS_run_meta_path = ""; - tensorflow::string FLAGS_op_log_path = ""; - tensorflow::string FLAGS_checkpoint_path = ""; - tensorflow::int32 FLAGS_max_depth = 10; - tensorflow::int64 FLAGS_min_bytes = 0; - 
tensorflow::int64 FLAGS_min_micros = 0; - tensorflow::int64 FLAGS_min_params = 0; - tensorflow::int64 FLAGS_min_float_ops = 0; - tensorflow::int64 FLAGS_min_occurrence = 0; - tensorflow::int64 FLAGS_step = -1; - tensorflow::string FLAGS_order_by = "name"; - tensorflow::string FLAGS_account_type_regexes = ".*"; - tensorflow::string FLAGS_start_name_regexes = ".*"; - tensorflow::string FLAGS_trim_name_regexes = ""; - tensorflow::string FLAGS_show_name_regexes = ".*"; - tensorflow::string FLAGS_hide_name_regexes; - bool FLAGS_account_displayed_op_only = false; - tensorflow::string FLAGS_select = "params"; - tensorflow::string FLAGS_output = ""; - for (int i = 0; i < argc; i++) { - fprintf(stderr, "%s\n", argv[i]); - } - - std::vector flag_list = { - tensorflow::Flag("graph_path", &FLAGS_graph_path, - "GraphDef proto text file name"), - tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path, - "Comma-separated list of RunMetadata proto binary " - "files. Each file is given step number 0,1,2,etc"), - tensorflow::Flag("op_log_path", &FLAGS_op_log_path, - "tensorflow::tfprof::OpLog proto binary file name"), - tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path, - "TensorFlow Checkpoint file name"), - tensorflow::Flag("max_depth", &FLAGS_max_depth, "max depth"), - tensorflow::Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"), - tensorflow::Flag("min_micros", &FLAGS_min_micros, "min micros"), - tensorflow::Flag("min_params", &FLAGS_min_params, "min params"), - tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"), - tensorflow::Flag("min_occurrence", &FLAGS_min_occurrence, - "min occurrence"), - tensorflow::Flag("step", &FLAGS_step, - "The stats of which step to use. 
By default average"), - tensorflow::Flag("order_by", &FLAGS_order_by, "order by"), - tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes, - "start name regexes"), - tensorflow::Flag("trim_name_regexes", &FLAGS_trim_name_regexes, - "trim name regexes"), - tensorflow::Flag("show_name_regexes", &FLAGS_show_name_regexes, - "show name regexes"), - tensorflow::Flag("hide_name_regexes", &FLAGS_hide_name_regexes, - "hide name regexes"), - tensorflow::Flag("account_displayed_op_only", - &FLAGS_account_displayed_op_only, - "account displayed op only"), - tensorflow::Flag("select", &FLAGS_select, "select"), - tensorflow::Flag("output", &FLAGS_output, "output"), - }; - tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_ok) { - printf("%s", usage.c_str()); - return (2); - } - tensorflow::port::InitMain(argv[0], &argc, &argv); - - fprintf(stderr, "%s\n", FLAGS_graph_path.c_str()); - - std::vector account_type_regexes = - Split(FLAGS_account_type_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector start_name_regexes = - Split(FLAGS_start_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector trim_name_regexes = - Split(FLAGS_trim_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector show_name_regexes = - Split(FLAGS_show_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector hide_name_regexes = - Split(FLAGS_hide_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector select = - Split(FLAGS_select, ',', tensorflow::str_util::SkipEmpty()); - - tensorflow::string output_type; - std::map output_options; - tensorflow::Status s = tensorflow::tfprof::ParseOutput( - FLAGS_output, &output_type, &output_options); - CHECK(s.ok()) << s.ToString(); - - tensorflow::string cmd = ""; - if (argc == 1 && FLAGS_graph_path.empty()) { - printf("1) go/tfprof: Tutorial.\n"); - printf("2) tfprof help: Detail help 
information.\n"); - printf( - "3) tfprof --graph_path : " - "Profiling model structure, tensor shape and # parameters.\n"); - printf( - "4) tfprof --graph_path \\\n" - " --run_meta_path \\\n" - " --op_log_path " - "\\\n" - " --checkpoint_path : " - "Profiling everything!\n"); - return 0; - } else if (argc > 1) { - if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[5]) { - tensorflow::tfprof::PrintHelp(); - return 0; - } - if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[0] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[1] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[2] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[3]) { - cmd = argv[1]; - } - } - - printf("Reading Files...\n"); - std::unique_ptr graph(new tensorflow::GraphDef()); - TF_CHECK_OK(tensorflow::tfprof::ReadGraphDef(tensorflow::Env::Default(), - FLAGS_graph_path, graph.get())); - - std::unique_ptr op_log( - new tensorflow::tfprof::OpLog()); - if (!FLAGS_op_log_path.empty()) { - tensorflow::string op_log_str; - s = tensorflow::ReadFileToString(tensorflow::Env::Default(), - FLAGS_op_log_path, &op_log_str); - if (!s.ok()) { - fprintf(stderr, "Failed to read op_log_path: %s\n", s.ToString().c_str()); - return 1; - } - if (!tensorflow::ParseProtoUnlimited(op_log.get(), op_log_str)) { - fprintf(stderr, "Failed to parse op_log_path\n"); - return 1; - } - } - - std::unique_ptr ckpt_reader; - TF_Status* status = TF_NewStatus(); - if (!FLAGS_checkpoint_path.empty()) { - ckpt_reader.reset(new tensorflow::checkpoint::CheckpointReader( - FLAGS_checkpoint_path, status)); - if (TF_GetCode(status) != TF_OK) { - fprintf(stderr, "%s\n", TF_Message(status)); - TF_DeleteStatus(status); - return 1; - } - TF_DeleteStatus(status); - } - - tensorflow::tfprof::TFStats tf_stat( - std::move(graph), nullptr, std::move(op_log), std::move(ckpt_reader)); - - std::vector run_meta_files = - Split(FLAGS_run_meta_path, ',', tensorflow::str_util::SkipEmpty()); - for (int 
i = 0; i < run_meta_files.size(); ++i) { - std::unique_ptr run_meta( - new tensorflow::RunMetadata()); - s = ReadBinaryProto(tensorflow::Env::Default(), run_meta_files[i], - run_meta.get()); - if (!s.ok()) { - fprintf(stderr, "Failed to read run_meta_path %s. Status: %s\n", - run_meta_files[i].c_str(), s.ToString().c_str()); - return 1; - } - tf_stat.ParseRunMeta(i, std::move(run_meta)); - } - - tensorflow::tfprof::Options opts( - FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros, FLAGS_min_params, - FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by, - account_type_regexes, start_name_regexes, trim_name_regexes, - show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only, - select, output_type, output_options); - - if (cmd == tensorflow::tfprof::kCmds[2] || - cmd == tensorflow::tfprof::kCmds[3]) { - tf_stat.ShowMultiGraphNode(cmd, opts); - return 0; - } else if (cmd == tensorflow::tfprof::kCmds[0] || - cmd == tensorflow::tfprof::kCmds[1]) { - tf_stat.ShowGraphNode(cmd, opts); - return 0; - } - - linenoiseSetCompletionCallback(completion); - linenoiseHistoryLoad(".tfprof_history.txt"); - - for (char* line = nullptr; (line = linenoise("tfprof> ")) != nullptr;) { - tensorflow::string line_s = line; - free(line); - - if (line_s.empty()) { - printf("%s", opts.ToString().c_str()); - continue; - } - linenoiseHistoryAdd(line_s.c_str()); - linenoiseHistorySave(".tfprof_history.txt"); - - tensorflow::tfprof::Options new_opts = opts; - tensorflow::Status s = - tensorflow::tfprof::ParseCmdLine(line_s, &cmd, &new_opts); - if (!s.ok()) { - fprintf(stderr, "E: %s\n", s.ToString().c_str()); - continue; - } - if (cmd == tensorflow::tfprof::kCmds[4]) { - opts = new_opts; - } else if (cmd == tensorflow::tfprof::kCmds[5]) { - tensorflow::tfprof::PrintHelp(); - } else if (cmd == tensorflow::tfprof::kCmds[2] || - cmd == tensorflow::tfprof::kCmds[3]) { - tf_stat.ShowMultiGraphNode(cmd, new_opts); - } else if (cmd == tensorflow::tfprof::kCmds[0] || - cmd 
== tensorflow::tfprof::kCmds[1]) { - tf_stat.ShowGraphNode(cmd, new_opts); - } - } - return 0; -} diff --git a/tensorflow/tools/tfprof/tfprof_options.proto b/tensorflow/tools/tfprof/tfprof_options.proto deleted file mode 100644 index 27eafb1ca9c27a8f03324bf95b31715014d5d95b..0000000000000000000000000000000000000000 --- a/tensorflow/tools/tfprof/tfprof_options.proto +++ /dev/null @@ -1,26 +0,0 @@ -syntax = "proto2"; - -package tensorflow.tfprof; - -// Refers to tfprof_options.h/cc for documentation. -// Only used to pass tfprof options from Python to C++. -message OptionsProto { - optional int64 max_depth = 1; - optional int64 min_bytes = 2; - optional int64 min_micros = 3; - optional int64 min_params = 4; - optional int64 min_float_ops = 5; - optional int64 min_occurrence = 17; - optional int64 step = 18 [default = -1]; - - optional string order_by = 7; - repeated string account_type_regexes = 8; - repeated string start_name_regexes = 9; - repeated string trim_name_regexes = 10; - repeated string show_name_regexes = 11; - repeated string hide_name_regexes = 12; - optional bool account_displayed_op_only = 13; - repeated string select = 14; - optional string output = 15; - optional string dump_to_file = 16; -} diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index ec5922ada8fc0003b8dc63a746a55c3ebe2b848f..037f3536bfa92c8d0a1398830c3348216b2c1e74 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -4,14 +4,8 @@ load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library_external") load("//third_party/py:python_configure.bzl", "python_configure") -load("//third_party:polymer.bzl", "tensorboard_polymer_workspace") 
-load("//third_party:python.bzl", "tensorboard_python_workspace") -load("//third_party:js.bzl", "tensorboard_js_workspace") -load("//third_party:typings.bzl", "tensorboard_typings_workspace") - def _is_windows(repository_ctx): """Returns true if the host operating system is windows.""" @@ -150,12 +144,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""): print("path_prefix was specified to tf_workspace but is no longer used " + "and will be removed in the future.") - # TODO(dandelion): Take these out when TB exits TF - tensorboard_polymer_workspace() - tensorboard_python_workspace() - tensorboard_typings_workspace() - tensorboard_js_workspace() - native.new_http_archive( name = "eigen_archive", urls = [ @@ -291,13 +279,46 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "six_archive", urls = [ "http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", - "http://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", ], sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", strip_prefix = "six-1.10.0", build_file = str(Label("//third_party:six.BUILD")), ) + native.new_http_archive( + name = "org_python_pypi_backports_weakref", + urls = [ + "http://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + ], + sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892", + strip_prefix = "backports.weakref-1.0rc1/src", + build_file = str(Label("//third_party:backports_weakref.BUILD")), + ) + + native.new_http_archive( + name = "com_github_andreif_codegen", + urls = [ + "http://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", + "https://github.com/andreif/codegen/archive/1.0.tar.gz", 
+ ], + sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee", + strip_prefix = "codegen-1.0", + build_file = str(Label("//third_party:codegen.BUILD")), + ) + + filegroup_external( + name = "org_python_license", + licenses = ["notice"], # Python 2.0 + sha256_urls = { + "b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [ + "http://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", + "https://docs.python.org/2.7/_sources/license.txt", + ], + }, + ) + native.bind( name = "six", actual = "@six_archive//:six", @@ -463,11 +484,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "llvm", urls = [ - "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/e156d99231a7735d06a97b5b83de70bf4ce4f034.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/e156d99231a7735d06a97b5b83de70bf4ce4f034.tar.gz", + "http://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/9e886358ff35a13de549b4adf49b52f933b9ec37.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/9e886358ff35a13de549b4adf49b52f933b9ec37.tar.gz", ], - sha256 = "72e34e2411a06d4200a2688ee83832805fbef23a12ea481f31c2b8866fde007a", - strip_prefix = "llvm-e156d99231a7735d06a97b5b83de70bf4ce4f034", + sha256 = "5a56369e906e5af2d4baf5a92317a3db085800e848def7114aba176c80432ea0", + strip_prefix = "llvm-9e886358ff35a13de549b4adf49b52f933b9ec37", build_file = str(Label("//third_party/llvm:llvm.BUILD")), repository = tf_repo_name, ) @@ -499,7 +520,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): actual = "@jsoncpp_git//:jsoncpp", ) - native.http_archive( + patched_http_archive( name = "boringssl", urls = [ "http://mirror.bazel.build/github.com/google/boringssl/archive/bbcaa15b0647816b9a1a9b9e0d209cd6712f0105.tar.gz", @@ -507,6 +528,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""): ], sha256 = "025264d6e9a7ad371f2f66d17a28b6627de0c9592dc2eb54afd062f68f1f9aa3", strip_prefix = 
"boringssl-bbcaa15b0647816b9a1a9b9e0d209cd6712f0105", + + # Add patch to boringssl code to support s390x + patch_file = str(Label("//third_party/boringssl:add_boringssl_s390x.patch")), ) native.new_http_archive( @@ -623,3 +647,18 @@ def tf_workspace(path_prefix="", tf_repo_name=""): build_file = str(Label("//third_party:pprof.BUILD")), ) + native.new_http_archive( + name = "cub_archive", + urls = [ + "http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.6.4.zip", + "https://github.com/NVlabs/cub/archive/1.6.4.zip", + ], + sha256 = "966d0c4f41e2bdc81aebf9ccfbf0baffaac5a74f00b826b06f4dee79b2bb8cee", + strip_prefix = "cub-1.6.4", + build_file = str(Label("//third_party:cub.BUILD")), + ) + + native.bind( + name = "cub", + actual = "@cub_archive//:cub", + ) diff --git a/third_party/backports_weakref.BUILD b/third_party/backports_weakref.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0adfc5f05419e736b6af01252674e6fb11e6b8d7 --- /dev/null +++ b/third_party/backports_weakref.BUILD @@ -0,0 +1,22 @@ +# Description: +# Backport of new features in Python's weakref module. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Python 2.0 + +py_library( + name = "org_python_pypi_backports_weakref", + srcs = [ + "backports/__init__.py", + "backports/weakref.py", + ], + srcs_version = "PY2AND3", +) + +genrule( + name = "license", + srcs = ["@org_python_license"], + outs = ["LICENSE"], + cmd = "cp $< $@", +) diff --git a/third_party/bleach.BUILD b/third_party/bleach.BUILD deleted file mode 100644 index 1bf75b84a769642d74b9fdef78708eaffceb113e..0000000000000000000000000000000000000000 --- a/third_party/bleach.BUILD +++ /dev/null @@ -1,20 +0,0 @@ -# Description: -# Build file for Bleach. 
-package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_library( - name = "org_mozilla_bleach", - srcs = [ - "bleach/__init__.py", - "bleach/callbacks.py", - "bleach/encoding.py", - "bleach/sanitizer.py", - "bleach/version.py", - ], - srcs_version = "PY2AND3", - deps = ["@org_html5lib"], -) diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_unsupported.json b/third_party/boringssl/BUILD similarity index 100% rename from tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_unsupported.json rename to third_party/boringssl/BUILD diff --git a/third_party/boringssl/add_boringssl_s390x.patch b/third_party/boringssl/add_boringssl_s390x.patch new file mode 100644 index 0000000000000000000000000000000000000000..0b41a4aa96831540bb55c69337bac1ed7b7cd651 --- /dev/null +++ b/third_party/boringssl/add_boringssl_s390x.patch @@ -0,0 +1,13 @@ +diff --git a/src/include/openssl/base.h b/src/include/openssl/base.h +index 7a3adfb..88012ad 100644 +--- a/src/include/openssl/base.h ++++ b/src/include/openssl/base.h +@@ -94,6 +94,8 @@ extern "C" { + #elif defined(__pnacl__) + #define OPENSSL_32_BIT + #define OPENSSL_PNACL ++#elif defined(__s390x__) ++#define OPENSSL_64_BIT + #else + #error "Unknown target CPU" + #endif diff --git a/third_party/clutz.BUILD b/third_party/clutz.BUILD deleted file mode 100644 index 593b70366a3a0908b91120ce5351fe7c2c0159b3..0000000000000000000000000000000000000000 --- a/third_party/clutz.BUILD +++ /dev/null @@ -1,44 +0,0 @@ -# Description: -# Build tool for making TypeScript .d.ts files from Closure JavaScript. 
- -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # MIT - -exports_files([ - "LICENSE", - "src/resources/closure.lib.d.ts", -]) - -JVM_FLAGS = [ - "-Xss20m", # JSCompiler needs big stacks for recursive parsing - "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive -] - -java_binary( - name = "clutz", - srcs = glob(["src/main/java/com/google/javascript/clutz/**/*.java"]), - jvm_flags = JVM_FLAGS, - main_class = "com.google.javascript.clutz.DeclarationGenerator", - deps = [ - "@args4j", - "@com_google_code_findbugs_jsr305", - "@com_google_code_gson", - "@com_google_guava", - "@com_google_javascript_closure_compiler", - ], -) - -java_binary( - name = "gents", - srcs = glob(["src/main/java/com/google/javascript/gents/**/*.java"]), - jvm_flags = JVM_FLAGS, - main_class = "com.google.javascript.gents.TypeScriptGenerator", - deps = [ - "@args4j", - "@com_google_code_findbugs_jsr305", - "@com_google_code_gson", - "@com_google_guava", - "@com_google_javascript_closure_compiler", - ], -) diff --git a/third_party/clutz.bzl b/third_party/clutz.bzl deleted file mode 100644 index f273c78c794c637f96af52c1c1aa96b31acc5a24..0000000000000000000000000000000000000000 --- a/third_party/clutz.bzl +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Build definitions for TypeScript from Closure JavaScript libraries.""" - -load("@io_bazel_rules_closure//closure/private:defs.bzl", - "JS_FILE_TYPE", - "collect_js", - "unfurl") - -CLUTZ_ATTRIBUTES = { - "_clutz": attr.label( - default=Label("@io_angular_clutz//:clutz"), - executable=True, - cfg="host"), - "_clutz_externs": attr.label( - default=Label("@com_google_javascript_closure_compiler_externs"), - allow_files=True), -} - -def extract_dts_from_closure_libraries(ctx): - """Extracts type definitions from closure dependencies. - - This just generates one big .d.ts file for all transitive Closure sources, - and does not pass it down. That means each rule has to duplicate the effort, - but on the other hand allows transitive dependencies on shared rules without - causing duplicate definition errors. - - Args: - ctx: A Skylark context. - Returns: - The generated Clutz typings file, or None if there were no JS deps. - """ - deps = unfurl(ctx.attr.deps, provider="closure_js_library") - js = collect_js(ctx, deps) - if not js.srcs: - return None - js_typings = ctx.new_file(ctx.bin_dir, "%s-js-typings.d.ts" % ctx.label.name) - srcs = depset(JS_FILE_TYPE.filter(ctx.files._clutz_externs)) + js.srcs - args = ["-o", js_typings.path] - for src in srcs: - args.append(src.path) - if getattr(ctx.attr, "clutz_entry_points", None): - args.append("--closure_entry_points") - args.extend(ctx.attr.clutz_entry_points) - ctx.action( - inputs=list(srcs), - outputs=[js_typings], - executable=ctx.executable._clutz, - arguments=args, - mnemonic="Clutz", - progress_message="Running Clutz on %d JS files %s" % ( - len(srcs), ctx.label)) - return js_typings - -################################################################################ -# The following definitions are for API compatibility with internal clutz.bzl - -CLUTZ_OUTPUTS = {} - -def _clutz_aspect_impl(target, ctx): - return struct() - -clutz_aspect = aspect( - implementation=_clutz_aspect_impl, - attr_aspects=["exports"]) 
diff --git a/third_party/codegen.BUILD b/third_party/codegen.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..df436c81635a71421a67fa8d8c84eb8dfcc97d7b --- /dev/null +++ b/third_party/codegen.BUILD @@ -0,0 +1,16 @@ +# -*- mode: python; -*- +# +# Description: +# Extension to ast that allow ast -> python code generation. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # New BSD + +exports_files(["LICENSE"]) + +py_library( + name = "com_github_andreif_codegen", + srcs = glob(["codegen.py"]), + srcs_version = "PY2AND3", +) diff --git a/third_party/cub.BUILD b/third_party/cub.BUILD new file mode 100644 index 0000000000000000000000000000000000000000..29159c9dad3d32121ce05278821e41b39f3f2a20 --- /dev/null +++ b/third_party/cub.BUILD @@ -0,0 +1,26 @@ +# Description: CUB library which is a set of primitives for GPU programming. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # BSD + +exports_files(["LICENSE.TXT"]) + +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda") + +filegroup( + name = "cub_header_files", + srcs = glob([ + "cub/**", + ]), +) + +cc_library( + name = "cub", + hdrs = if_cuda([":cub_header_files"]), + deps = [ + "@local_config_cuda//cuda:cuda_headers", + ], +) diff --git a/third_party/gpus/crosstool/remote.BUILD.tpl b/third_party/gpus/crosstool/remote.BUILD.tpl new file mode 100644 index 0000000000000000000000000000000000000000..b2316331db257a39086bdd5ca02b5ca6848cebcb --- /dev/null +++ b/third_party/gpus/crosstool/remote.BUILD.tpl @@ -0,0 +1,10 @@ +# Description: +# Template for crosstool Build file to use a pre-generated config. 
+licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +alias( + name = "toolchain", + actual = "%{remote_cuda_repo}:toolchain", +) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index f7610dd7a99e3c65ac494d23f0a408d4391680c0..51d9e4e994ecc6084b3eabf19069db93b43a8165 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -40,20 +40,23 @@ config_setting( cc_library( name = "cuda_headers", hdrs = [ - "cuda_config.h", + "cuda/cuda_config.h", %{cuda_headers} ], includes = [ ".", - "include", + "cuda/include", ], visibility = ["//visibility:public"], ) cc_library( name = "cudart_static", - srcs = ["lib/%{cudart_static_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cudart_static_lib}"], + includes = [ + ".", + "cuda/include", + ], linkopts = select({ ":freebsd": [], "//conditions:default": ["-ldl"], @@ -66,62 +69,83 @@ cc_library( cc_library( name = "cuda_driver", - srcs = ["lib/%{cuda_driver_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cuda_driver_lib}"], + includes = [ + ".", + "cuda/include", + ], visibility = ["//visibility:public"], ) cc_library( name = "cudart", - srcs = ["lib/%{cudart_lib}"], - data = ["lib/%{cudart_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cudart_lib}"], + data = ["cuda/lib/%{cudart_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cublas", - srcs = ["lib/%{cublas_lib}"], - data = ["lib/%{cublas_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cublas_lib}"], + data = ["cuda/lib/%{cublas_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cusolver", - srcs = ["lib/%{cusolver_lib}"], - data = ["lib/%{cusolver_lib}"], - includes = ["include"], - linkstatic = 1, + srcs = ["cuda/lib/%{cusolver_lib}"], + data = ["cuda/lib/%{cusolver_lib}"], + includes = [ + ".", + 
"cuda/include", + ], linkopts = ["-lgomp"], + linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cudnn", - srcs = ["lib/%{cudnn_lib}"], - data = ["lib/%{cudnn_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cudnn_lib}"], + data = ["cuda/lib/%{cudnn_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "cufft", - srcs = ["lib/%{cufft_lib}"], - data = ["lib/%{cufft_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{cufft_lib}"], + data = ["cuda/lib/%{cufft_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) cc_library( name = "curand", - srcs = ["lib/%{curand_lib}"], - data = ["lib/%{curand_lib}"], - includes = ["include"], + srcs = ["cuda/lib/%{curand_lib}"], + data = ["cuda/lib/%{curand_lib}"], + includes = [ + ".", + "cuda/include", + ], linkstatic = 1, visibility = ["//visibility:public"], ) @@ -129,9 +153,9 @@ cc_library( cc_library( name = "cuda", deps = [ + ":cublas", ":cuda_headers", ":cudart", - ":cublas", ":cudnn", ":cufft", ":curand", @@ -142,19 +166,23 @@ cc_library( cc_library( name = "cupti_headers", hdrs = [ - "cuda_config.h", + "cuda/cuda_config.h", ":cuda-extras", ], includes = [ ".", - "extras/CUPTI/include/", + "cuda/extras/CUPTI/include/", ], visibility = ["//visibility:public"], ) cc_library( name = "cupti_dsos", - data = ["lib/%{cupti_lib}"], + data = ["cuda/lib/%{cupti_lib}"], + includes = [ + ".", + "cuda/include", + ], visibility = ["//visibility:public"], ) diff --git a/third_party/gpus/cuda/remote.BUILD.tpl b/third_party/gpus/cuda/remote.BUILD.tpl new file mode 100644 index 0000000000000000000000000000000000000000..d88d512b90c352e6a301ed6efe8266d8dd6bf744 --- /dev/null +++ b/third_party/gpus/cuda/remote.BUILD.tpl @@ -0,0 +1,105 @@ +# Description: +# Template for cuda Build file to use a pre-generated config. 
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "using_nvcc", + values = { + "define": "using_cuda_nvcc=true", + }, +) + +config_setting( + name = "using_clang", + values = { + "define": "using_cuda_clang=true", + }, +) + +# Equivalent to using_clang && -c opt. +config_setting( + name = "using_clang_opt", + values = { + "define": "using_cuda_clang=true", + "compilation_mode": "opt", + }, +) + +config_setting( + name = "darwin", + values = {"cpu": "darwin"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "freebsd", + values = {"cpu": "freebsd"}, + visibility = ["//visibility:public"], +) + +alias( + name = "cuda_headers", + actual = "%{remote_cuda_repo}cuda:cuda_headers", +) + +alias( + name = "cudart_static", + actual = "%{remote_cuda_repo}cuda:cudart_static", +) + +alias( + name = "cuda_driver", + actual = "%{remote_cuda_repo}cuda:cuda_driver", +) + +alias( + name = "cudart", + actual = "%{remote_cuda_repo}cuda:cudart", +) + +alias( + name = "cublas", + actual = "%{remote_cuda_repo}cuda:cublas", +) + +alias( + name = "cusolver", + actual = "%{remote_cuda_repo}cuda:cusolver", +) + +alias( + name = "cudnn", + actual = "%{remote_cuda_repo}cuda:cudnn", +) + +alias( + name = "cufft", + actual = "%{remote_cuda_repo}cuda:cufft", +) + +alias( + name = "curand", + actual = "%{remote_cuda_repo}cuda:curand", +) + +alias( + name = "cuda", + actual = "%{remote_cuda_repo}cuda:cuda", +) + +alias( + name = "cupti_headers", + actual = "%{remote_cuda_repo}cuda:cupti_headers", +) + +alias( + name = "cupti_dsos", + actual = "%{remote_cuda_repo}cuda:cupti_dsos", +) + +alias( + name = "libdevice_root", + actual = "%{remote_cuda_repo}cuda:libdevice_root", +) diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 61932a8e6d1a699392c4de73ee36ed681d9eda94..4dd3169d418797fbda656d33c53e3f147b38725d 100644 --- 
a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -26,6 +26,7 @@ _TF_CUDA_VERSION = "TF_CUDA_VERSION" _TF_CUDNN_VERSION = "TF_CUDNN_VERSION" _CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH" _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _DEFAULT_CUDA_VERSION = "" _DEFAULT_CUDNN_VERSION = "" @@ -739,19 +740,19 @@ def _create_dummy_repository(repository_ctx): # Create dummy files for the CUDA toolkit since they are still required by # tensorflow/core/platform/default/build_config:cuda. - repository_ctx.file("cuda/include/cuda.h", "") - repository_ctx.file("cuda/include/cublas.h", "") - repository_ctx.file("cuda/include/cudnn.h", "") - repository_ctx.file("cuda/extras/CUPTI/include/cupti.h", "") - repository_ctx.file("cuda/lib/%s" % _lib_name("cuda", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cudart", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cudart_static", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cublas", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cusolver", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cudnn", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("curand", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cufft", cpu_value)) - repository_ctx.file("cuda/lib/%s" % _lib_name("cupti", cpu_value)) + repository_ctx.file("cuda/cuda/include/cuda.h", "") + repository_ctx.file("cuda/cuda/include/cublas.h", "") + repository_ctx.file("cuda/cuda/include/cudnn.h", "") + repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h", "") + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cuda", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart_static", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cublas", cpu_value)) + 
repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cusolver", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudnn", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("curand", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cufft", cpu_value)) + repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cupti", cpu_value)) # Set up cuda_config.h, which is used by # tensorflow/stream_executor/dso_loader.cc. @@ -763,7 +764,7 @@ def _create_dummy_repository(repository_ctx): "CudaVersion(\"%s\")" % c for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES]), "%{cuda_toolkit_path}": _DEFAULT_CUDA_TOOLKIT_PATH, - }) + }, "cuda/cuda/cuda_config.h") # If cuda_configure is not configured to build with GPU support, and the user # attempts to build with --config=cuda, add a dummy build rule to intercept @@ -820,6 +821,13 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, dest_files = files.replace(src_dir, '').splitlines() src_files = files.splitlines() command = [] + if not _is_windows(repository_ctx): + # We clear folders that might have been generated previously to avoid + # undesired inclusions + command.append('if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi') + command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi') + command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi') + command.append('if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi') outs = [] for i in range(len(dest_files)): if dest_files[i] != "": @@ -829,7 +837,7 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, # On Windows, symlink is not supported, so we just copy all the files. 
cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s' command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest)) - outs.append(' "' + dest_dir + dest_files[i] + '",') + outs.append(' "' + dest_dir + dest_files[i] + '",') genrule = _genrule(src_dir, genrule_name, " && ".join(command), "\n".join(outs)) return genrule @@ -846,11 +854,11 @@ def _genrule(src_dir, genrule_name, command, outs): genrule_name + '",\n' + ' outs = [\n' + outs + - ' ],\n' + + '\n ],\n' + ' cmd = """\n' + command + - ' """,\n' + - ')\n\n' + '\n """,\n' + + ')\n' ) @@ -883,15 +891,16 @@ def _use_cuda_clang(repository_ctx): return enable_cuda == "1" return False -def _compute_cuda_extra_copts(repository_ctx, cuda_config): +def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): if _use_cuda_clang(repository_ctx): - capability_flags = ["--cuda-gpu-arch=sm_" + cap.replace(".", "") for cap in cuda_config.compute_capabilities] + capability_flags = ["--cuda-gpu-arch=sm_" + + cap.replace(".", "") for cap in compute_capabilities] else: # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc capability_flags = [] return str(capability_flags) -def _create_cuda_repository(repository_ctx): +def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" cuda_config = _get_cuda_config(repository_ctx) @@ -904,19 +913,19 @@ def _create_cuda_repository(repository_ctx): cuda_toolkit_path = cuda_config.cuda_toolkit_path cuda_include_path = cuda_toolkit_path + "/include" genrules = [_symlink_genrule_for_dir(repository_ctx, - cuda_include_path, "include", "cuda-include")] + cuda_include_path, "cuda/include", "cuda-include")] genrules.append(_symlink_genrule_for_dir(repository_ctx, - cuda_toolkit_path + "/nvvm", "nvvm", "cuda-nvvm")) + cuda_toolkit_path + "/nvvm", "cuda/nvvm", "cuda-nvvm")) genrules.append(_symlink_genrule_for_dir(repository_ctx, cuda_toolkit_path + "/extras/CUPTI/include", - "extras/CUPTI/include", 
"cuda-extras")) + "cuda/extras/CUPTI/include", "cuda-extras")) cuda_libs = _find_libs(repository_ctx, cuda_config) cuda_lib_src = [] cuda_lib_dest = [] for lib in cuda_libs.values(): cuda_lib_src.append(lib.path) - cuda_lib_dest.append("lib/" + lib.file_name) + cuda_lib_dest.append("cuda/lib/" + lib.file_name) genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib", cuda_lib_src, cuda_lib_dest)) @@ -925,8 +934,9 @@ def _create_cuda_repository(repository_ctx): included_files = _read_dir(repository_ctx, cuda_include_path).replace( cuda_include_path, '').splitlines() if '/cudnn.h' not in included_files: - genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "include/", - "cudnn-include", [cudnn_header_dir + "/cudnn.h"], ["cudnn.h"])) + genrules.append(_symlink_genrule_for_dir(repository_ctx, None, + "cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"], + ["cudnn.h"])) else: genrules.append( 'filegroup(\n' + @@ -939,7 +949,8 @@ def _create_cuda_repository(repository_ctx): _tpl(repository_ctx, "cuda:build_defs.bzl", { "%{cuda_is_configured}": "True", - "%{cuda_extra_copts}": _compute_cuda_extra_copts(repository_ctx, cuda_config), + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + repository_ctx, cuda_config.compute_capabilities), }) _tpl(repository_ctx, "cuda:BUILD", @@ -997,16 +1008,35 @@ def _create_cuda_repository(repository_ctx): ["CudaVersion(\"%s\")" % c for c in cuda_config.compute_capabilities]), "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path, - }) + }, "cuda/cuda/cuda_config.h") + +def _create_remote_cuda_repository(repository_ctx, remote_config_repo): + """Creates pointers to a remotely configured repo set up to build with CUDA.""" + _tpl(repository_ctx, "cuda:build_defs.bzl", + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + repository_ctx, _compute_capabilities(repository_ctx)), + }) + _tpl(repository_ctx, "cuda:remote.BUILD", + { + "%{remote_cuda_repo}": 
remote_config_repo, + }, "cuda/BUILD") + _tpl(repository_ctx, "crosstool:remote.BUILD", { + "%{remote_cuda_repo}": remote_config_repo, + }, "crosstool/BUILD") def _cuda_autoconf_impl(repository_ctx): """Implementation of the cuda_autoconf repository rule.""" if not _enable_cuda(repository_ctx): _create_dummy_repository(repository_ctx) else: - _create_cuda_repository(repository_ctx) - + if _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ: + _create_remote_cuda_repository(repository_ctx, + repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO]) + else: + _create_local_cuda_repository(repository_ctx) cuda_configure = repository_rule( @@ -1019,6 +1049,7 @@ cuda_configure = repository_rule( _TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_COMPUTE_CAPABILITIES, + _TF_CUDA_CONFIG_REPO, ], ) diff --git a/third_party/html5lib.BUILD b/third_party/html5lib.BUILD deleted file mode 100644 index 63aac14f1559a86f626a5d99db973111f86f92ae..0000000000000000000000000000000000000000 --- a/third_party/html5lib.BUILD +++ /dev/null @@ -1,17 +0,0 @@ -# Description: -# Import of html5lib library. - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # BSD-like notice-style license, see LICENSE file - -exports_files(["LICENSE"]) - -py_library( - name = "org_html5lib", - srcs = glob(["html5lib/**/*.py"]), - srcs_version = "PY2AND3", - deps = [ - "@six_archive//:six", - ], -) diff --git a/third_party/js.bzl b/third_party/js.bzl deleted file mode 100644 index 2d2339c95e5b537ae9ba0ebe8044808ebe411a36..0000000000000000000000000000000000000000 --- a/third_party/js.bzl +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TensorBoard external JS dependencies (both infrastructure and frontend libs) -load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library_external") - - - ############################################################################## - # TensorBoard Build Tools -def tensorboard_js_workspace(): - filegroup_external( - name = "org_nodejs", - # MIT with portions licensed: - # - MIT - # - Old MIT - # - 2-Clause-BSD - # - 3-Clause-BSD - # - ISC - # - Unicode - # - zlib - # - Artistic 2.0 - licenses = ["notice"], - sha256_urls_extract_macos = { - "47109a00cac344d80296c195451bb5eee7c21727fcef1594384ddfe1f852957a": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/node-v4.3.2-darwin-x64.tar.xz", - "http://nodejs.org/dist/v4.3.2/node-v4.3.2-darwin-x64.tar.xz", - ], - }, - sha256_urls_windows = { - "3d4cfca9dcec556a077a2324bf5bd165ea3e6e64a2bfd7fc6e7a1f0dc4eb552b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/nodejs/node/v4.3.2/LICENSE", - "https://raw.githubusercontent.com/nodejs/node/v4.3.2/LICENSE", - ], - "606c44c42d17866c017c50c0afadad411d9492ac4281d2431b937f881911614e": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/win-x64/node.exe", - "http://nodejs.org/dist/v4.3.2/win-x64/node.exe", - ], - "451a40570099a95488d6438f175813629e0430f87f23c8659bc18dc42494820a": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/win-x64/node.lib", - "http://nodejs.org/dist/v4.3.2/win-x64/node.lib", - ], - }, - sha256_urls_extract = { - 
"4350d0431b49697517c6cca5d66adf5f74eb9101c52f52ae959fa94225822d44": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/node-v4.3.2-linux-x64.tar.xz", - "http://nodejs.org/dist/v4.3.2/node-v4.3.2-linux-x64.tar.xz", - ], - }, - strip_prefix = { - "node-v4.3.2-darwin-x64.tar.xz": "node-v4.3.2-darwin-x64", - "node-v4.3.2-linux-x64.tar.xz": "node-v4.3.2-linux-x64", - }, - executable = [ - "node", - "node.exe", - ], - ) - - filegroup_external( - name = "com_microsoft_typescript", - licenses = ["notice"], # Apache 2.0 - sha256_urls = { - "a7d00bfd54525bc694b6e32f64c7ebcf5e6b7ae3657be5cc12767bce74654a47": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/LICENSE.txt", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/LICENSE.txt", - ], - "8465342c318f9c4cf0a29b109fa63ee3742dd4dc7080d05d9fd8f604814d04cf": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", - ], - "a67e36da3029d232e4e938e61a0a3302f516d71e7100d54dbf5362ad8618e994": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", - ], - }, - extra_build_file_content = "\n".join([ - "sh_binary(", - " name = \"tsc\",", - " srcs = [\"tsc.sh\"],", - " data = [", - " \"tsc.js\",", - " \"@org_nodejs\",", - " ],", - ")", - "", - "genrule(", - " name = \"tsc_sh\",", - " outs = [\"tsc.sh\"],", - " cmd = \"cat >$@ <<'EOF'\\n\" +", - " \"#!/bin/bash\\n\" +", - " \"NODE=external/org_nodejs/bin/node\\n\" +", - " \"if [[ -e external/org_nodejs/node.exe ]]; then\\n\" +", - " \" NODE=external/org_nodejs/node.exe\\n\" +", - " \"fi\\n\" +", - " \"exec $${NODE} external/com_microsoft_typescript/tsc.js \\\"$$@\\\"\\n\" +", - " \"EOF\",", - " executable = True,", - ")", - ]), - ) - - - native.new_http_archive( - name = 
"io_angular_clutz", - build_file = "//third_party:clutz.BUILD", - sha256 = "2981de41d1ff4774b544423da9a2cd8beb3be649e95aef2ef2fd83957300b3fe", - strip_prefix = "clutz-b0db5ade9bb535d387f05292316c422790c9848e", - urls = [ - "http://mirror.bazel.build/github.com/angular/clutz/archive/b0db5ade9bb535d387f05292316c422790c9848e.tar.gz", # 2017-05-22 - "https://github.com/angular/clutz/archive/b0db5ade9bb535d387f05292316c422790c9848e.tar.gz", - ], - ) - - filegroup_external( - name = "com_google_javascript_closure_compiler_externs", - licenses = ["notice"], # Apache 2.0 - sha256_urls_extract = { - "0f515a6ebfa138490b3c5ea9f3591ea1a7e4a930d3074f18b3eca86084ad9b66": [ - "http://mirror.bazel.build/github.com/google/closure-compiler/archive/b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz", # 2017-06-02 - "https://github.com/google/closure-compiler/archive/b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz", - ], - }, - strip_prefix = {"b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz": "closure-compiler-b37e6000001b0a6bf4c0be49024ebda14a8711d9/externs"}, - ) - - filegroup_external( - name = "com_google_javascript_closure_compiler_externs_polymer", - licenses = ["notice"], # Apache 2.0 - sha256_urls = { - "23baad9a200a717a821c6df504c84d3a893d7ea9102b14876eb80097e3b94292": [ - "http://mirror.bazel.build/raw.githubusercontent.com/google/closure-compiler/0e8dc5597a295ee259e3fecd98d6535dc621232f/contrib/externs/polymer-1.0.js", # 2017-05-27 - "https://raw.githubusercontent.com/google/closure-compiler/0e8dc5597a295ee259e3fecd98d6535dc621232f/contrib/externs/polymer-1.0.js", - ], - }, - ) - - filegroup_external( - name = "org_threejs", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "7aff264bd84c90bed3c72a4dc31db8c19151853c6df6980f52b01d3e9872c82d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/build/three.js", - 
"https://raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/build/three.js", - ], - "0e98ded15bb7fe398a655667e76b39909d36c0973a8950d01c62f65f93161c27": [ - "http://mirror.bazel.build/raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/examples/js/controls/OrbitControls.js", - "https://raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/examples/js/controls/OrbitControls.js", - ], - }, - ) - - ############################################################################## - # TensorBoard JavaScript Production Dependencies - web_library_external( - name = "com_lodash", - licenses = ["notice"], # MIT - sha256 = "0e88207e5f90af4ce8790d6e1e7d09d2702d81bce0bafdc253d18c0a5bf7661e", - urls = [ - "http://mirror.bazel.build/github.com/lodash/lodash/archive/3.10.1.tar.gz", - "https://github.com/lodash/lodash/archive/3.10.1.tar.gz", - ], - strip_prefix = "lodash-3.10.1", - path = "/lodash", - srcs = ["lodash.js"], - ) - - filegroup_external( - name = "com_numericjs", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "0e94aada97f12dee6118064add9170484c55022f5d53206ee4407143cd36ddcd": [ - "http://mirror.bazel.build/raw.githubusercontent.com/sloisel/numeric/v1.2.6/license.txt", - "https://raw.githubusercontent.com/sloisel/numeric/v1.2.6/license.txt", - ], - "dfaca3b8485bee735788cc6eebca82ea25719adc1fb8911c7799c6bd5a95df3b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/sloisel/numeric/v1.2.6/src/numeric.js", - "https://raw.githubusercontent.com/sloisel/numeric/v1.2.6/src/numeric.js", - ], - }, - ) - - filegroup_external( - name = "com_palantir_plottable", - # no @license header - licenses = ["notice"], # MIT - sha256_urls_extract = { - # Plottable doesn't have a release tarball on GitHub. Using the - # sources directly from git also requires running Node tooling - # beforehand to generate files. NPM is the only place to get it. 
- "e3159beb279391c47433789f22b32bac88488cfcad6c0b6ec8605ce6b0081b0d": [ - "http://mirror.bazel.build/registry.npmjs.org/plottable/-/plottable-3.1.0.tgz", - "https://registry.npmjs.org/plottable/-/plottable-3.1.0.tgz", - ], - }, - ) - - filegroup_external( - name = "io_github_cpettitt_dagre", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "6a349742a6cb219d5a2fc8d0844f6d89a6efc62e20c664450d884fc7ff2d6015": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/dagre/v0.7.4/LICENSE", - "https://raw.githubusercontent.com/cpettitt/dagre/v0.7.4/LICENSE", - ], - "7323829ddd77924a69e2b1235ded3eac30acd990da0f037e0fbd3c8e9035b50d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/dagre/v0.7.4/dist/dagre.core.js", - "https://raw.githubusercontent.com/cpettitt/dagre/v0.7.4/dist/dagre.core.js", - ], - }, - ) - - filegroup_external( - name = "io_github_cpettitt_graphlib", - licenses = ["notice"], # MIT - sha256_urls = { - "6a349742a6cb219d5a2fc8d0844f6d89a6efc62e20c664450d884fc7ff2d6015": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/LICENSE", - "https://raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/LICENSE", - ], - "772045d412b1513b549be991c2e1846c38019429d43974efcae943fbe83489bf": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/dist/graphlib.core.js", - "https://raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/dist/graphlib.core.js", - ], - }, - ) - - filegroup_external( - name = "io_github_waylonflinn_weblas", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "633f2861a9a862b9cd7967e841e14dd3527912f209d6563595774fa31e3d84cb": [ - "http://mirror.bazel.build/raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/LICENSES", - "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/LICENSE", - ], - "f138fce57f673ca8a633f4aee5ae5b6fcb6ad0de59069a42a74e996fd04d8fcc": [ - 
"http://mirror.bazel.build/raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js", - "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js", - ], - }, - ) - - filegroup_external( - name = "org_d3js", - # no @license header - licenses = ["notice"], # BSD-3-Clause - sha256_urls_extract = { - "b5fac5b296bc196e6aa7b59f9e33986fc44d23d59a0e211705187be9e35b943d": [ - "http://mirror.bazel.build/github.com/d3/d3/releases/download/v4.8.0/d3.zip", - "https://github.com/d3/d3/releases/download/v4.8.0/d3.zip", - ], - }, - # TODO(jart): Use srcs=["d3.js"] instead of this once supported. - generated_rule_name = "all_files", - extra_build_file_content = "\n".join([ - "filegroup(", - " name = \"org_d3js\",", - " srcs = [\"d3.js\"],", - ")", - ]), - ) - - filegroup_external( - name = "org_chromium_catapult_vulcanized_trace_viewer", - licenses = ["notice"], # BSD-3-Clause - sha256_urls = { - "f0df289ba9d03d857ad1c2f5918861376b1510b71588ffc60eff5c7a7bfedb09": [ - "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE", - "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE", - ], - "9e99e79439ea5a1471bd4dd325bd6733e133bcb3da4df4b878ed6d2aec7c8d86": [ - "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html", - "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html" - ], - }, - ) - - ############################################################################## - # TensorBoard Testing Dependencies - web_library_external( - name = "org_npmjs_registry_accessibility_developer_tools", - licenses = ["notice"], # Apache License 2.0 - sha256 = "1d6a72f401c9d53f68238c617dd43a05cd85ca5aa2e676a5b3c352711448e093", - urls = [ - 
"http://mirror.bazel.build/registry.npmjs.org/accessibility-developer-tools/-/accessibility-developer-tools-2.10.0.tgz", - "https://registry.npmjs.org/accessibility-developer-tools/-/accessibility-developer-tools-2.10.0.tgz", - ], - strip_prefix = "package", - path = "/accessibility-developer-tools", - suppress = ["strictDependencies"], - ) - - web_library_external( - name = "org_npmjs_registry_async", - licenses = ["notice"], # MIT - sha256 = "08655255ae810bf4d1cb1642df57658fcce823776d3ba8f4b46f4bbff6c87ece", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/async/-/async-1.5.0.tgz", - "https://registry.npmjs.org/async/-/async-1.5.0.tgz", - ], - strip_prefix = "package", - path = "/async", - ) - - web_library_external( - name = "org_npmjs_registry_chai", - licenses = ["notice"], # MIT - sha256 = "aca8137bed5bb295bd7173325b7ad604cd2aeb341d739232b4f9f0b26745be90", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/chai/-/chai-3.5.0.tgz", - "https://registry.npmjs.org/chai/-/chai-3.5.0.tgz", - ], - strip_prefix = "package", - path = "/chai", - ) - - web_library_external( - name = "org_npmjs_registry_mocha", - licenses = ["notice"], # MIT - sha256 = "13ef37a071196a2fba680799b906555d3f0ab61e80a7e8f73f93e77914590dd4", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/mocha/-/mocha-2.5.3.tgz", - "https://registry.npmjs.org/mocha/-/mocha-2.5.3.tgz", - ], - suppress = ["strictDependencies"], - strip_prefix = "package", - path = "/mocha", - ) - - web_library_external( - name = "org_npmjs_registry_sinon", - licenses = ["notice"], # BSD-3-Clause - sha256 = "49edb057695fc9019aae992bf7e677a07de7c6ce2bf9f9facde4a245045d1532", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/sinon/-/sinon-1.17.4.tgz", - "https://registry.npmjs.org/sinon/-/sinon-1.17.4.tgz", - ], - strip_prefix = "package/lib", - path = "/sinonjs", - ) - - web_library_external( - name = "org_npmjs_registry_sinon_chai", - licenses = ["notice"], # BSD-3-Clause - sha256 = 
"b85fc56f713832960b56fe9269ee4bb2cd41edd2ceb130b0936e5bdbed5dea63", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/sinon-chai/-/sinon-chai-2.8.0.tgz", - "https://registry.npmjs.org/sinon-chai/-/sinon-chai-2.8.0.tgz", - ], - strip_prefix = "package", - path = "/sinon-chai", - ) - - web_library_external( - name = "org_npmjs_registry_stacky", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c659e60f7957d9d80c23a7aacc4d71b19c6421a08f91174c0062de369595acae", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/stacky/-/stacky-1.3.1.tgz", - "https://registry.npmjs.org/stacky/-/stacky-1.3.1.tgz", - ], - strip_prefix = "package", - path = "/stacky", - ) - - web_library_external( - name = "org_npmjs_registry_web_component_tester", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9d4ebd4945df8a936916d4d32b7f280f2a3afa35f79e7ca8ad3ed0a42770c537", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/web-component-tester/-/web-component-tester-4.3.6.tgz", - "https://registry.npmjs.org/web-component-tester/-/web-component-tester-4.3.6.tgz", - ], - strip_prefix = "package", - path = "/web-component-tester", - suppress = [ - "absolutePaths", - "strictDependencies", - ], - deps = [ - "@com_lodash", - "@org_npmjs_registry_accessibility_developer_tools", - "@org_npmjs_registry_async", - "@org_npmjs_registry_chai", - "@org_npmjs_registry_mocha", - "@org_npmjs_registry_sinon", - "@org_npmjs_registry_sinon_chai", - "@org_npmjs_registry_stacky", - "@org_polymer_test_fixture", - ], - ) - - web_library_external( - name = "org_polymer_test_fixture", - licenses = ["notice"], # BSD-3-Clause - sha256 = "59d6cfb1187733b71275becfea181fe0aa1f734df5ff77f5850c806bbbf9a0d9", - strip_prefix = "test-fixture-2.0.1", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/test-fixture/archive/v2.0.1.tar.gz", - "https://github.com/PolymerElements/test-fixture/archive/v2.0.1.tar.gz", - ], - path = "/test-fixture", - exclude = ["test/**"], - ) - diff --git 
a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index 32266997a7e54c09525a60a48d2ad330941e2668..3b13b297f8ab63a23859d83ea7882aa2f3869f56 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -257,6 +257,16 @@ cc_library( includes = ["include"], ) +# A creator of an empty file include/llvm/Support/VCSRevision.h. +# This is usually populated by the upstream build infrastructure, but in this +# case we leave it blank. See upstream revision r300160. +genrule( + name = "vcs_revision_gen", + srcs = [], + outs = ["include/llvm/Support/VCSRevision.h"], + cmd = "echo '' > \"$@\"", +) + # Rules that apply the LLVM tblgen tool. gentbl( name = "intrinsics_gen", @@ -453,6 +463,7 @@ llvm_target_list = [ "include/llvm/IR/Intrinsics*.td", "include/llvm/TableGen/*.td", "include/llvm/Target/*.td", + "include/llvm/Target/GlobalISel/*.td", ]), ) for target in llvm_target_list @@ -1067,6 +1078,8 @@ cc_library( "include/llvm/BinaryFormat/*.h", "include/llvm/BinaryFormat/*.def", "include/llvm/BinaryFormat/*.inc", + "include/llvm/BinaryFormat/ELFRelocs/*.def", + "include/llvm/BinaryFormat/WasmRelocs/*.def", ]), deps = [ ":config", @@ -1169,7 +1182,7 @@ cc_library( "include/llvm/IR/*.def", "include/llvm/IR/*.inc", "include/llvm/*.h", - ]), + ]) + ["include/llvm/Support/VCSRevision.h"], deps = [ ":attributes_compat_gen", ":attributes_gen", @@ -1194,6 +1207,7 @@ cc_library( "include/llvm/DebugInfo/CodeView/*.inc", ]), deps = [ + ":binary_format", ":config", ":debug_info_msf", ":support", @@ -1426,6 +1440,7 @@ cc_library( "include/llvm/MC/*.inc", ]), deps = [ + ":binary_format", ":config", ":debug_info_code_view", ":support", @@ -1921,6 +1936,8 @@ cc_library( "lib/Support/Unix/*.h", "include/llvm-c/*.h", "include/llvm/CodeGen/MachineValueType.h", + "include/llvm/BinaryFormat/COFF.h", + "include/llvm/BinaryFormat/MachO.h", "lib/Support/*.h", ]), hdrs = glob([ @@ -1931,6 +1948,7 @@ cc_library( "include/llvm/Support/ELFRelocs/*.def", 
"include/llvm/Support/WasmRelocs/*.def", ]) + [ + "include/llvm/BinaryFormat/MachO.def", "include/llvm/Support/DataTypes.h", "include/llvm/ExecutionEngine/ObjectMemoryBuffer.h", ], diff --git a/third_party/lmdb.BUILD b/third_party/lmdb.BUILD index 7c6e3dc3f0531f7e2dc3c4ad782a6a02a6b4e514..61228bfd4376a03b30d576ed52085c653eb5a9c2 100644 --- a/third_party/lmdb.BUILD +++ b/third_party/lmdb.BUILD @@ -19,8 +19,8 @@ cc_library( "-w", ], linkopts = select({ - ":windows": ["-Wl,advapi32.lib"], # InitializeSecurityDescriptor, SetSecurityDescriptorDacl - ":windows_msvc": ["-Wl,advapi32.lib"], + ":windows": ["-DEFAULTLIB:advapi32.lib"], # InitializeSecurityDescriptor, SetSecurityDescriptorDacl + ":windows_msvc": ["-DEFAULTLIB:advapi32.lib"], "//conditions:default": ["-lpthread"], }), visibility = ["//visibility:public"], diff --git a/third_party/markdown.BUILD b/third_party/markdown.BUILD deleted file mode 100644 index fa3e85d5304083ed0de521c93c5ea1df1f477349..0000000000000000000000000000000000000000 --- a/third_party/markdown.BUILD +++ /dev/null @@ -1,15 +0,0 @@ -# Description: -# Markdown processor - -package(default_visibility = ["//visibility:public"]) - -# This software says they use a BSD license. -licenses(["notice"]) - -exports_files(["LICENSE.md"]) - -py_library( - name = "org_pythonhosted_markdown", - srcs = glob(["markdown/**/*.py"]), - srcs_version = "PY2AND3", -) diff --git a/third_party/polymer.bzl b/third_party/polymer.bzl deleted file mode 100644 index bd6e05803cf39192092fb20015c7abe520e8903e..0000000000000000000000000000000000000000 --- a/third_party/polymer.bzl +++ /dev/null @@ -1,1335 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TensorBoard Polymer Dependencies - -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library_external") - -def tensorboard_polymer_workspace(): - web_library_external( - name = "org_polymer_font_roboto", - licenses = ["notice"], # BSD-3-Clause - sha256 = "fae51429b56a4a4c15f1f0c23b733c7095940cc9c04c275fa7adb3bf055b23b3", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/font-roboto/archive/v1.0.1.tar.gz", - "https://github.com/PolymerElements/font-roboto/archive/v1.0.1.tar.gz", - ], - strip_prefix = "font-roboto-1.0.1", - path = "/font-roboto", - srcs = ["roboto.html"], - ) - - web_library_external( - name = "org_polymer_hydrolysis", - licenses = ["notice"], # BSD-3-Clause - sha256 = "703b50f6b00f9e0546b5a3451da57bb20f77a166e27e4967923b9e835bab9b80", - urls = [ - "http://mirror.bazel.build/github.com/Polymer/polymer-analyzer/archive/v1.19.3.tar.gz", - "https://github.com/Polymer/polymer-analyzer/archive/v1.19.3.tar.gz", - ], - strip_prefix = "polymer-analyzer-1.19.3", - path = "/hydrolysis", - srcs = [ - "hydrolysis-analyzer.html", - "hydrolysis.html", - "hydrolysis.js", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_a11y_announcer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6bce143db7a374a68535ec8b861a5f30e81f2f1e4ee36a55bda2a891f6fd2818", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-a11y-announcer/archive/v1.0.5.tar.gz", - "https://github.com/PolymerElements/iron-a11y-announcer/archive/v1.0.5.tar.gz", - ], - strip_prefix = 
"iron-a11y-announcer-1.0.5", - path = "/iron-a11y-announcer", - srcs = ["iron-a11y-announcer.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_a11y_keys_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6823efc47a83208fd51d39c5a1d3eb0c0bebc705df1ce01310509da22a13ebd2", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz", - "https://github.com/PolymerElements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz", - ], - strip_prefix = "iron-a11y-keys-behavior-1.1.8", - path = "/iron-a11y-keys-behavior", - srcs = ["iron-a11y-keys-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_ajax", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9162d8af4611e911ac3ebbfc08bb7038ac04f6e79a9287b1476fe36ad6770bc5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-ajax/archive/v1.2.0.tar.gz", - "https://github.com/PolymerElements/iron-ajax/archive/v1.2.0.tar.gz", - ], - strip_prefix = "iron-ajax-1.2.0", - path = "/iron-ajax", - srcs = [ - "iron-ajax.html", - "iron-request.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_promise_polyfill", - ], - ) - - web_library_external( - name = "org_polymer_iron_autogrow_textarea", - licenses = ["notice"], # BSD-3-Clause - sha256 = "50bbb901d2c8f87462e3552e3d671a552faa12c37c485e548d7a234ebffbc427", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-autogrow-textarea/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/iron-autogrow-textarea/archive/v1.0.12.tar.gz", - ], - strip_prefix = "iron-autogrow-textarea-1.0.12", - path = "/iron-autogrow-textarea", - srcs = ["iron-autogrow-textarea.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = 
"org_polymer_iron_behaviors", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a1e8d4b7a13f3d36beba9c2a6b186ed33a53e6af2e79f98c1fcc7e85e7b53f89", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-behaviors/archive/v1.0.17.tar.gz", - "https://github.com/PolymerElements/iron-behaviors/archive/v1.0.17.tar.gz", - ], - strip_prefix = "iron-behaviors-1.0.17", - path = "/iron-behaviors", - srcs = [ - "iron-button-state.html", - "iron-control-state.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_checked_element_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "539a0e1c4df0bc702d3bd342388e4e56c77ec4c2066cce69e41426a69f92e8bd", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-checked-element-behavior/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/iron-checked-element-behavior/archive/v1.0.4.tar.gz", - ], - strip_prefix = "iron-checked-element-behavior-1.0.4", - path = "/iron-checked-element-behavior", - srcs = ["iron-checked-element-behavior.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_component_page", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3636e8b9a1f229fc33b5aad3933bd02a9825f66e679a0be31855d7c8245c4b4b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-component-page/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/iron-component-page/archive/v1.1.4.tar.gz", - ], - strip_prefix = "iron-component-page-1.1.4", - path = "/iron-component-page", - srcs = ["iron-component-page.html"], - deps = [ - "@org_polymer", - "@org_polymer_hydrolysis", - "@org_polymer_iron_ajax", - "@org_polymer_iron_doc_viewer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icons", - "@org_polymer_iron_selector", - 
"@org_polymer_paper_header_panel", - "@org_polymer_paper_styles", - "@org_polymer_paper_toolbar", - ], - ) - - web_library_external( - name = "org_polymer_iron_collapse", - licenses = ["notice"], # BSD-3-Clause - sha256 = "275808994a609a2f9923e2dd2db1957945ab141ba840eadc33f19e1f406d600e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-collapse/archive/v1.0.8.tar.gz", - "https://github.com/PolymerElements/iron-collapse/archive/v1.0.8.tar.gz", - ], - strip_prefix = "iron-collapse-1.0.8", - path = "/iron-collapse", - srcs = ["iron-collapse.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_resizable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_demo_helpers", - licenses = ["notice"], # BSD-3-Clause - sha256 = "aa7458492a6ac3d1f6344640a4c2ab07bce64e7ad0422b83b5d665707598cce6", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-demo-helpers/archive/v1.1.0.tar.gz", - "https://github.com/PolymerElements/iron-demo-helpers/archive/v1.1.0.tar.gz", - ], - strip_prefix = "iron-demo-helpers-1.1.0", - path = "/iron-demo-helpers", - srcs = [ - "demo-pages-shared-styles.html", - "demo-snippet.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icons", - "@org_polymer_marked_element", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_styles", - "@org_polymer_prism_element", - ], - ) - - web_library_external( - name = "org_polymer_iron_doc_viewer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f0e9dfbbcd94d7e88ce82cb61e615406ace63c185fee9396f7f182206ca5cc9a", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-doc-viewer/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/iron-doc-viewer/archive/v1.0.12.tar.gz", - ], - strip_prefix = "iron-doc-viewer-1.0.12", - path = "/iron-doc-viewer", - srcs = [ - "iron-doc-property-styles.html", - "iron-doc-property.html", - "iron-doc-viewer-styles.html", - 
"iron-doc-viewer.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_marked_element", - "@org_polymer_paper_button", - "@org_polymer_paper_styles", - "@org_polymer_prism_element", - ], - ) - - web_library_external( - name = "org_polymer_iron_dropdown", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f7e4a31d096d10d8af1920397695cb17f3eb1cbe5e5ff91a861dabfcc085f376", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-dropdown/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/iron-dropdown/archive/v1.4.0.tar.gz", - ], - strip_prefix = "iron-dropdown-1.4.0", - path = "/iron-dropdown", - srcs = [ - "iron-dropdown.html", - "iron-dropdown-scroll-manager.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_overlay_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_neon_animation", - ], - ) - - web_library_external( - name = "org_polymer_iron_fit_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "10132a2ea309a37c4c07b8fead71f64abc588ee6107931e34680f5f36dd8291e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-fit-behavior/archive/v1.2.5.tar.gz", - "https://github.com/PolymerElements/iron-fit-behavior/archive/v1.2.5.tar.gz", - ], - strip_prefix = "iron-fit-behavior-1.2.5", - path = "/iron-fit-behavior", - srcs = ["iron-fit-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_flex_layout", - licenses = ["notice"], # BSD-3-Clause - sha256 = "79287f6ca1c2d4e003f68b88fe19d03a1b6a0011e2b4cae579fe4d1474163a2e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-flex-layout/archive/v1.3.0.tar.gz", - "https://github.com/PolymerElements/iron-flex-layout/archive/v1.3.0.tar.gz", - ], - strip_prefix = "iron-flex-layout-1.3.0", - path = "/iron-flex-layout", - srcs = [ - "classes/iron-flex-layout.html", - "classes/iron-shadow-flex-layout.html", 
- "iron-flex-layout.html", - "iron-flex-layout-classes.html", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_form_element_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "1dd9371c638e5bc2ecba8a64074aa680dfb8712198e9612f9ed24d387efc8f26", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-form-element-behavior/archive/v1.0.6.tar.gz", - "https://github.com/PolymerElements/iron-form-element-behavior/archive/v1.0.6.tar.gz", - ], - strip_prefix = "iron-form-element-behavior-1.0.6", - path = "/iron-form-element-behavior", - srcs = ["iron-form-element-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_icon", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9ed58a69159a02c07a6050d242e6d4e585a29f3245b8c8c390cfd52ddb786dc4", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-icon/archive/v1.0.11.tar.gz", - "https://github.com/PolymerElements/iron-icon/archive/v1.0.11.tar.gz", - ], - strip_prefix = "iron-icon-1.0.11", - path = "/iron-icon", - srcs = ["iron-icon.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_iron_icons", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3b18542c147c7923dc3a36b1a51984a73255d610f297d43c9aaccc52859bd0d0", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-icons/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/iron-icons/archive/v1.1.3.tar.gz", - ], - strip_prefix = "iron-icons-1.1.3", - path = "/iron-icons", - srcs = [ - "av-icons.html", - "communication-icons.html", - "device-icons.html", - "editor-icons.html", - "hardware-icons.html", - "image-icons.html", - "iron-icons.html", - "maps-icons.html", - "notification-icons.html", - "places-icons.html", - "social-icons.html", - ], - deps = [ - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - ], 
- ) - - web_library_external( - name = "org_polymer_iron_iconset_svg", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7e3925b7e63a7d22524c4b43ce16ab80d06a576649644783643c11a003284368", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-iconset-svg/archive/v1.1.0.tar.gz", - "https://github.com/PolymerElements/iron-iconset-svg/archive/v1.1.0.tar.gz", - ], - strip_prefix = "iron-iconset-svg-1.1.0", - path = "/iron-iconset-svg", - srcs = ["iron-iconset-svg.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_iron_input", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c505101ead08ab25526b1f49baecc8c28b4221b92a65e7334c783bdc81553c36", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-input/archive/1.0.10.tar.gz", - "https://github.com/PolymerElements/iron-input/archive/1.0.10.tar.gz", - ], - strip_prefix = "iron-input-1.0.10", - path = "/iron-input", - srcs = ["iron-input.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_announcer", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_list", - licenses = ["notice"], # BSD-3-Clause - sha256 = "72a6530b9f0ad5557f5d287845792a0ada74d8b159198e27f940e226313dc116", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-list/archive/v1.3.9.tar.gz", - "https://github.com/PolymerElements/iron-list/archive/v1.3.9.tar.gz", - ], - strip_prefix = "iron-list-1.3.9", - path = "/iron-list", - srcs = ["iron-list.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_iron_scroll_target_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_menu_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ad27889343bc9a709258b073f69abc028bb1ffd3fdb975cd2d3939f7f5d7bb6c", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/iron-menu-behavior/archive/v1.1.10.tar.gz", - "https://github.com/PolymerElements/iron-menu-behavior/archive/v1.1.10.tar.gz", - ], - strip_prefix = "iron-menu-behavior-1.1.10", - path = "/iron-menu-behavior", - srcs = [ - "iron-menu-behavior.html", - "iron-menubar-behavior.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_selector", - ], - ) - - web_library_external( - name = "org_polymer_iron_meta", - licenses = ["notice"], # BSD-3-Clause - sha256 = "fb05e6031bae6b4effe5f15d44b3f548d5807f9e3b3aa2442ba17cf4b8b84361", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-meta/archive/v1.1.1.tar.gz", - "https://github.com/PolymerElements/iron-meta/archive/v1.1.1.tar.gz", - ], - strip_prefix = "iron-meta-1.1.1", - path = "/iron-meta", - srcs = ["iron-meta.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_overlay_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3df5b54ff2e0510c87a2aff8c9d730d3fe83d3d11277cc1a49fa29b549acb46c", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-overlay-behavior/archive/v1.10.1.tar.gz", - "https://github.com/PolymerElements/iron-overlay-behavior/archive/v1.10.1.tar.gz", - ], - strip_prefix = "iron-overlay-behavior-1.10.1", - path = "/iron-overlay-behavior", - srcs = [ - "iron-focusables-helper.html", - "iron-overlay-backdrop.html", - "iron-overlay-behavior.html", - "iron-overlay-manager.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_fit_behavior", - "@org_polymer_iron_resizable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_range_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "b2f2b6d52284542330bd30b586e217926eb0adec5e13934a3cef557717c22dc2", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/iron-range-behavior/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/iron-range-behavior/archive/v1.0.4.tar.gz", - ], - strip_prefix = "iron-range-behavior-1.0.4", - path = "/iron-range-behavior", - srcs = ["iron-range-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_resizable_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a87a78ee9223c2f6afae7fc94a3ff91cbce6f7e2a7ed3f2979af7945c9281616", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-resizable-behavior/archive/v1.0.3.tar.gz", - "https://github.com/PolymerElements/iron-resizable-behavior/archive/v1.0.3.tar.gz", - ], - strip_prefix = "iron-resizable-behavior-1.0.3", - path = "/iron-resizable-behavior", - srcs = ["iron-resizable-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_scroll_target_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "d0de0c804b1ec91d814754144afd9da1cdb082690de88bd5e47fd5f41990746f", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz", - "https://github.com/PolymerElements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz", - ], - strip_prefix = "iron-scroll-target-behavior-1.0.3", - path = "/iron-scroll-target-behavior", - srcs = ["iron-scroll-target-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_selector", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ba28a47443bad3b744611c9d7a79fb21dbdf2e35edc5ef8f812e2dcd72b16747", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-selector/archive/v1.5.2.tar.gz", - "https://github.com/PolymerElements/iron-selector/archive/v1.5.2.tar.gz", - ], - strip_prefix = "iron-selector-1.5.2", - path = "/iron-selector", - srcs = [ - "iron-multi-selectable.html", - "iron-selectable.html", - 
"iron-selection.html", - "iron-selector.html", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_validatable_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "aef4901e68043824f36104799269573dd345ffaac494186e466fdc79c06fdb63", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-validatable-behavior/archive/v1.1.1.tar.gz", - "https://github.com/PolymerElements/iron-validatable-behavior/archive/v1.1.1.tar.gz", - ], - strip_prefix = "iron-validatable-behavior-1.1.1", - path = "/iron-validatable-behavior", - srcs = ["iron-validatable-behavior.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_marked", - licenses = ["notice"], # MIT - sha256 = "93d30bd593736ca440938d77808b7ef5972da0f3fcfe4ae63ae7b4ce117da2cb", - urls = [ - "http://mirror.bazel.build/github.com/chjj/marked/archive/v0.3.2.zip", - "https://github.com/chjj/marked/archive/v0.3.2.zip", - ], - strip_prefix = "marked-0.3.2", - path = "/marked", - srcs = ["lib/marked.js"], - ) - - web_library_external( - name = "org_polymer_marked_element", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7547616df95f8b903757e6afbabfcdba5322c2bcec3f17c726b8bba5adf4bc5f", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/marked-element/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/marked-element/archive/v1.1.3.tar.gz", - ], - strip_prefix = "marked-element-1.1.3", - path = "/marked-element", - srcs = [ - "marked-element.html", - "marked-import.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_marked", - ], - ) - - web_library_external( - name = "org_polymer_neon_animation", - licenses = ["notice"], # BSD-3-Clause - sha256 = "8800c314a76b2da190a2b203259c1091f6d38e0057ed37c2a3d0b734980fa9a5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/neon-animation/archive/v1.2.2.tar.gz", - 
"https://github.com/PolymerElements/neon-animation/archive/v1.2.2.tar.gz", - ], - strip_prefix = "neon-animation-1.2.2", - path = "/neon-animation", - srcs = [ - "animations/cascaded-animation.html", - "animations/fade-in-animation.html", - "animations/fade-out-animation.html", - "animations/hero-animation.html", - "animations/opaque-animation.html", - "animations/reverse-ripple-animation.html", - "animations/ripple-animation.html", - "animations/scale-down-animation.html", - "animations/scale-up-animation.html", - "animations/slide-down-animation.html", - "animations/slide-from-bottom-animation.html", - "animations/slide-from-left-animation.html", - "animations/slide-from-right-animation.html", - "animations/slide-from-top-animation.html", - "animations/slide-left-animation.html", - "animations/slide-right-animation.html", - "animations/slide-up-animation.html", - "animations/transform-animation.html", - "neon-animatable.html", - "neon-animatable-behavior.html", - "neon-animated-pages.html", - "neon-animation.html", - "neon-animation-behavior.html", - "neon-animation-runner-behavior.html", - "neon-animations.html", - "neon-shared-element-animatable-behavior.html", - "neon-shared-element-animation-behavior.html", - "web-animations.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_iron_selector", - "@org_polymer_web_animations_js", - ], - ) - - web_library_external( - name = "org_polymer_paper_behaviors", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7cfcb9082ef9909da262df6b5c120bc62dbeaff278cb563e8fc60465ddd387e5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-behaviors/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/paper-behaviors/archive/v1.0.12.tar.gz", - ], - strip_prefix = "paper-behaviors-1.0.12", - path = "/paper-behaviors", - srcs = [ - "paper-button-behavior.html", - "paper-checked-element-behavior.html", - 
"paper-inky-focus-behavior.html", - "paper-ripple-behavior.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_checked_element_behavior", - "@org_polymer_paper_ripple", - ], - ) - - web_library_external( - name = "org_polymer_paper_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "896c0a7e34bfcce63fc23c63e105ed9c4d62fa3a6385b7161e1e5cd4058820a6", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-button/archive/v1.0.11.tar.gz", - "https://github.com/PolymerElements/paper-button/archive/v1.0.11.tar.gz", - ], - strip_prefix = "paper-button-1.0.11", - path = "/paper-button", - srcs = ["paper-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_material", - "@org_polymer_paper_ripple", - ], - ) - - web_library_external( - name = "org_polymer_paper_checkbox", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6828a6954a048b1230fbd2606faffbae950ba1d042175b96ec50ae355786a166", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-checkbox/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/paper-checkbox/archive/v1.4.0.tar.gz", - ], - strip_prefix = "paper-checkbox-1.4.0", - path = "/paper-checkbox", - srcs = ["paper-checkbox.html"], - deps = [ - "@org_polymer", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c6a9709e7f528d03dcd574503c18b72d4751ca30017346d16e6a791d37ed9259", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/paper-dialog/archive/v1.0.4.tar.gz", - ], - strip_prefix = "paper-dialog-1.0.4", - path = "/paper-dialog", - srcs = ["paper-dialog.html"], - deps = [ - "@org_polymer", - "@org_polymer_neon_animation", - 
"@org_polymer_paper_dialog_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a7e0e27ce63554bc14f384cf94bcfa24da8dc5f5120dfd565f45e166261aee40", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog-behavior/archive/v1.2.5.tar.gz", - "https://github.com/PolymerElements/paper-dialog-behavior/archive/v1.2.5.tar.gz", - ], - strip_prefix = "paper-dialog-behavior-1.2.5", - path = "/paper-dialog-behavior", - srcs = [ - "paper-dialog-behavior.html", - "paper-dialog-common.css", - "paper-dialog-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_overlay_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog_scrollable", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a2e69283e7674f782c44d811387a0f8da2d01fac0172743d1add65e253e6b5ff", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog-scrollable/archive/1.1.5.tar.gz", - "https://github.com/PolymerElements/paper-dialog-scrollable/archive/1.1.5.tar.gz", - ], - strip_prefix = "paper-dialog-scrollable-1.1.5", - path = "/paper-dialog-scrollable", - srcs = ["paper-dialog-scrollable.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_dialog_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dropdown_menu", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9d88f654ec03ee9be211df9e69bede9e8a22b51bf1dbcc63b79762e4256d81ad", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dropdown-menu/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/paper-dropdown-menu/archive/v1.4.0.tar.gz", - ], - strip_prefix = "paper-dropdown-menu-1.4.0", - path = "/paper-dropdown-menu", - srcs = [ - "paper-dropdown-menu.html", - "paper-dropdown-menu-icons.html", - 
"paper-dropdown-menu-light.html", - "paper-dropdown-menu-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - "@org_polymer_iron_validatable_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_input", - "@org_polymer_paper_menu_button", - "@org_polymer_paper_ripple", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_header_panel", - licenses = ["notice"], # BSD-3-Clause - sha256 = "0db4bd8a4bf6f20dcd0dffb4f907b31c93a8647c9c021344239cf30b40b87075", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-header-panel/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-header-panel/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-header-panel-1.1.4", - path = "/paper-header-panel", - srcs = ["paper-header-panel.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - ], - ) - - web_library_external( - name = "org_polymer_paper_icon_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9cba5bcfd6aeb4c41581c1392c678cf2278d360e9d122f4d9db54a9ebb404496", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-icon-button/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/paper-icon-button/archive/v1.1.3.tar.gz", - ], - strip_prefix = "paper-icon-button-1.1.3", - path = "/paper-icon-button", - srcs = [ - "paper-icon-button.html", - "paper-icon-button-light.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_icon", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_input", - licenses = ["notice"], # BSD-3-Clause - sha256 = "17c3dea9bb1c2026cc61324696c6c774214a0dc37686b91ca214a6af550994db", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/paper-input/archive/v1.1.18.tar.gz", - "https://github.com/PolymerElements/paper-input/archive/v1.1.18.tar.gz", - ], - strip_prefix = "paper-input-1.1.18", - path = "/paper-input", - srcs = [ - "paper-input.html", - "paper-input-addon-behavior.html", - "paper-input-behavior.html", - "paper-input-char-counter.html", - "paper-input-container.html", - "paper-input-error.html", - "paper-textarea.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_autogrow_textarea", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_input", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_item", - licenses = ["notice"], # BSD-3-Clause - sha256 = "12ee0dcb61b0d5721c5988571f6974d7b2211e97724f4195893fbcc9058cdac8", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-item/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-item/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-item-1.1.4", - path = "/paper-item", - srcs = [ - "paper-icon-item.html", - "paper-item.html", - "paper-item-behavior.html", - "paper-item-body.html", - "paper-item-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_listbox", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3cb35f4fe9a3f15185a9e91711dba8f27e9291c8cd371ebf1be21b8f1d5f65fb", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-listbox/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-listbox/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-listbox-1.1.2", - path = "/paper-listbox", - srcs = ["paper-listbox.html"], - deps = [ - "@org_polymer", - 
"@org_polymer_iron_menu_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_material", - licenses = ["notice"], # BSD-3-Clause - sha256 = "09f6c8bd6ddbea2be541dc86306efe41cdfb31bec0b69d35a5dc29772bbc8506", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-material/archive/v1.0.6.tar.gz", - "https://github.com/PolymerElements/paper-material/archive/v1.0.6.tar.gz", - ], - strip_prefix = "paper-material-1.0.6", - path = "/paper-material", - srcs = [ - "paper-material.html", - "paper-material-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_menu", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a3cee220926e315f7412236b3628288774694447c0da4428345f36d0f127ba3b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-menu/archive/v1.2.2.tar.gz", - "https://github.com/PolymerElements/paper-menu/archive/v1.2.2.tar.gz", - ], - strip_prefix = "paper-menu-1.2.2", - path = "/paper-menu", - srcs = [ - "paper-menu.html", - "paper-menu-shared-styles.html", - "paper-submenu.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_collapse", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_menu_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_menu_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "be3290c288a2bd4f9887213db22c75add99cc29ff4d088100c0bc4eb0e57997b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-menu-button/archive/v1.5.1.tar.gz", - "https://github.com/PolymerElements/paper-menu-button/archive/v1.5.1.tar.gz", - ], - strip_prefix = "paper-menu-button-1.5.1", - path = "/paper-menu-button", - srcs = [ - "paper-menu-button.html", - "paper-menu-button-animations.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - 
"@org_polymer_iron_behaviors", - "@org_polymer_iron_dropdown", - "@org_polymer_neon_animation", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_progress", - licenses = ["notice"], # BSD-3-Clause - sha256 = "2b6776b2f023c1f344feea17ba29b58d879e46f8ed43b7256495054b5183fff6", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-progress/archive/v1.0.9.tar.gz", - "https://github.com/PolymerElements/paper-progress/archive/v1.0.9.tar.gz", - ], - strip_prefix = "paper-progress-1.0.9", - path = "/paper-progress", - srcs = ["paper-progress.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_range_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_radio_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6e911d0c308aa388136b3af79d1bdcbe5a1f4159cbc79d71efb4ff3b6c0b4e91", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-radio-button/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-radio-button/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-radio-button-1.1.2", - path = "/paper-radio-button", - srcs = ["paper-radio-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_radio_group", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7885ad1f81e9dcc03dcea4139b54a201ff55c18543770cd44f94530046c9e163", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-radio-group/archive/v1.0.9.tar.gz", - "https://github.com/PolymerElements/paper-radio-group/archive/v1.0.9.tar.gz", - ], - strip_prefix = "paper-radio-group-1.0.9", - path = "/paper-radio-group", - srcs = ["paper-radio-group.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_selector", - "@org_polymer_paper_radio_button", - ], - ) - 
- web_library_external( - name = "org_polymer_paper_ripple", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ba76bfb1c737260a8a103d3ca97faa1f7c3288c7db9b2519f401b7a782147c09", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-ripple/archive/v1.0.5.tar.gz", - "https://github.com/PolymerElements/paper-ripple/archive/v1.0.5.tar.gz", - ], - strip_prefix = "paper-ripple-1.0.5", - path = "/paper-ripple", - srcs = ["paper-ripple.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_slider", - licenses = ["notice"], # BSD-3-Clause - sha256 = "08e7c541dbf5d2e959208810bfc03188e82ced87e4d30d325172967f67962c3c", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-slider/archive/v1.0.10.tar.gz", - "https://github.com/PolymerElements/paper-slider/archive/v1.0.10.tar.gz", - ], - strip_prefix = "paper-slider-1.0.10", - path = "/paper-slider", - srcs = ["paper-slider.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_range_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_input", - "@org_polymer_paper_progress", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_spinner", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6a752907fab7899cbeed15b478e7b9299047c15fbf9d1561d6eb4d204bdbd178", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-spinner/archive/v1.1.1.tar.gz", - "https://github.com/PolymerElements/paper-spinner/archive/v1.1.1.tar.gz", - ], - strip_prefix = "paper-spinner-1.1.1", - path = "/paper-spinner", - srcs = [ - "paper-spinner.html", "paper-spinner-behavior.html", - "paper-spinner-lite.html", "paper-spinner-styles.html" - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - 
) - - web_library_external( - name = "org_polymer_paper_styles", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6d26b0a4c286402098853dc7388f6b22f30dfb7a74e47b34992ac03380144bb2", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-styles/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-styles/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-styles-1.1.4", - path = "/paper-styles", - srcs = [ - "classes/global.html", - "classes/shadow.html", - "classes/shadow-layout.html", - "classes/typography.html", - "color.html", - "default-theme.html", - "demo.css", - "demo-pages.html", - "paper-styles.html", - "paper-styles-classes.html", - "shadow.html", - "typography.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_font_roboto", - "@org_polymer_iron_flex_layout", - ], - ) - - web_library_external( - name = "org_polymer_paper_tabs", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c23b6a5221db35e5b1ed3eb8e8696b952572563e285adaec96aba1e3134db825", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-tabs/archive/v1.7.0.tar.gz", - "https://github.com/PolymerElements/paper-tabs/archive/v1.7.0.tar.gz", - ], - strip_prefix = "paper-tabs-1.7.0", - path = "/paper-tabs", - srcs = [ - "paper-tab.html", - "paper-tabs.html", - "paper-tabs-icons.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - "@org_polymer_iron_menu_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_toast", - licenses = ["notice"], # BSD-3-Clause - sha256 = "55f623712ed1f2bae6d6fadc522a2458e083ccd44cc0a907672547e7b10758a9", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-toast/archive/v1.3.0.tar.gz", - 
"https://github.com/PolymerElements/paper-toast/archive/v1.3.0.tar.gz", - ], - strip_prefix = "paper-toast-1.3.0", - path = "/paper-toast", - srcs = ["paper-toast.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_announcer", - "@org_polymer_iron_overlay_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_toggle_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4aa7cf0396fa2994a8bc2ac6e8428f48b07b945bb7c41bd52041ef5827b45de3", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-toggle-button/archive/v1.2.0.tar.gz", - "https://github.com/PolymerElements/paper-toggle-button/archive/v1.2.0.tar.gz", - ], - strip_prefix = "paper-toggle-button-1.2.0", - path = "/paper-toggle-button", - srcs = ["paper-toggle-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_toolbar", - licenses = ["notice"], # BSD-3-Clause - sha256 = "dbddffc0654d9fb5fb48843087eebe16bf7a134902495a664c96c11bf8a2c63d", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-toolbar/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-toolbar/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-toolbar-1.1.4", - path = "/paper-toolbar", - srcs = ["paper-toolbar.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_tooltip", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4c6667acf01f73da14c3cbc0aa574bf14280304567987ee0314534328377d2ad", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-tooltip/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-tooltip/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-tooltip-1.1.2", - path = "/paper-tooltip", - srcs = ["paper-tooltip.html"], - deps = [ - 
"@org_polymer", - "@org_polymer_neon_animation", - ], - ) - - web_library_external( - name = "org_polymer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "07a9e62ffb52193da3af09adda2fbac5cc690439978520e2d03e783863f65f91", - strip_prefix = "polymer-1.7.0", - urls = [ - "http://mirror.bazel.build/github.com/polymer/polymer/archive/v1.7.0.tar.gz", - "https://github.com/polymer/polymer/archive/v1.7.0.tar.gz", - ], - path = "/polymer", - srcs = [ - "polymer.html", - "polymer-micro.html", - "polymer-mini.html", - ], - ) - - web_library_external( - name = "org_polymer_prism", - licenses = ["notice"], # MIT - sha256 = "e06eb54f2a80e6b3cd0bd4d59f900423bcaee53fc03998a056df63740c684683", - urls = [ - "http://mirror.bazel.build/github.com/PrismJS/prism/archive/abee2b7587f1925e57777044270e2a1860810994.tar.gz", - "https://github.com/PrismJS/prism/archive/abee2b7587f1925e57777044270e2a1860810994.tar.gz", - ], - strip_prefix = "prism-abee2b7587f1925e57777044270e2a1860810994", - path = "/prism", - srcs = [ - "prism.js", - "themes/prism.css", - ], - ) - - web_library_external( - name = "org_polymer_prism_element", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ad70bf9cd5bbdf525d465e1b0658867ab4022193eb9c74087a839044b46312b4", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/prism-element/archive/1.0.4.tar.gz", - "https://github.com/PolymerElements/prism-element/archive/1.0.4.tar.gz", - ], - strip_prefix = "prism-element-1.0.4", - path = "/prism-element", - srcs = [ - "prism-highlighter.html", - "prism-import.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_prism", - ], - ) - - web_library_external( - name = "org_polymer_promise_polyfill", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4495450e5d884c3e16b537b43afead7f84d17c7dc061bcfcbf440eac083e4ef5", - strip_prefix = "promise-polyfill-1.0.0", - urls = [ - "http://mirror.bazel.build/github.com/PolymerLabs/promise-polyfill/archive/v1.0.0.tar.gz", - 
"https://github.com/PolymerLabs/promise-polyfill/archive/v1.0.0.tar.gz", - ], - path = "/promise-polyfill", - srcs = [ - "Promise.js", - "Promise-Statics.js", - "promise-polyfill.html", - "promise-polyfill-lite.html" - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_web_animations_js", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f8bd760cbdeba131f6790bd5abe170bcbf7b1755ff58ed16d0b82fa8a7f34a7f", - urls = [ - "http://mirror.bazel.build/github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz", - "https://github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz", - ], - strip_prefix = "web-animations-js-2.2.1", - path = "/web-animations-js", - srcs = ["web-animations-next-lite.min.js"], - ) - - web_library_external( - name = "org_polymer_webcomponentsjs", - licenses = ["notice"], # BSD-3-Clause - sha256 = "138c43306ee0a6d699ddca9b3c6b0f4982974ea8b7bdad291ea7276c72301df9", - urls = [ - "http://mirror.bazel.build/github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz", - "https://github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz", - ], - strip_prefix = "webcomponentsjs-0.7.22", - path = "/webcomponentsjs", - srcs = [ - "CustomElements.js", - "CustomElements.min.js", - "HTMLImports.js", - "HTMLImports.min.js", - "MutationObserver.js", - "MutationObserver.min.js", - "ShadowDOM.js", - "ShadowDOM.min.js", - "webcomponents.js", - "webcomponents.min.js", - "webcomponents-lite.js", - "webcomponents-lite.min.js", - ], - ) diff --git a/third_party/py/BUILD.tpl b/third_party/py/BUILD.tpl index 1ee9c071adb2d9f4aec84b92277c5067f153b666..2f3503e7930948280b0cc973b6910cd29cecb8f2 100644 --- a/third_party/py/BUILD.tpl +++ b/third_party/py/BUILD.tpl @@ -6,6 +6,16 @@ cc_library( name = "python_headers", hdrs = [":python_include"], includes = ["python_include"], + data = select({ + ":windows" : [":python_import_lib"], + "//conditions:default": [], + }), + linkopts = select({ + # TODO(pcloudy): Ideally, this 
should just go into deps after resolving + # https://github.com/bazelbuild/bazel/issues/3237, + ":windows" : ["$(locations :python_import_lib)"], + "//conditions:default": [], + }), ) cc_library( @@ -21,5 +31,5 @@ config_setting( ) %{PYTHON_INCLUDE_GENRULE} - +%{PYTHON_IMPORT_LIB_GENRULE} %{NUMPY_INCLUDE_GENRULE} diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl index b4a98af7b6e7742ba99829e6b5e7ce13224cb217..bbc07905fc7f92a26d0aebade66a20209dc3e766 100644 --- a/third_party/py/python_configure.bzl +++ b/third_party/py/python_configure.bzl @@ -9,10 +9,9 @@ * `PYTHON_LIB_PATH`: Location of python libraries. """ -_NUMPY_INCLUDE_PATH = "NUMPY_INCLUDE_PATH" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" -_PYTHON_INCLUDE_PATH = "PYTHON_INCLUDE_PATH" _PYTHON_LIB_PATH = "PYTHON_LIB_PATH" +_TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO" def _tpl(repository_ctx, tpl, substitutions={}, out=None): @@ -116,11 +115,11 @@ def _genrule(src_dir, genrule_name, command, outs): genrule_name + '",\n' + ' outs = [\n' + outs + - ' ],\n' + + '\n ],\n' + ' cmd = """\n' + command + - ' """,\n' + - ')\n\n' + '\n """,\n' + + ')\n' ) @@ -132,15 +131,20 @@ def _norm_path(path): return path -def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name): +def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, + src_files = [], dest_files = []): """Returns a genrule to symlink(or copy if on Windows) a set of files. + + If src_dir is passed, files will be read from the given directory; otherwise + we assume files are in src_files and dest_files """ - src_dir = _norm_path(src_dir) - dest_dir = _norm_path(dest_dir) - files = _read_dir(repository_ctx, src_dir) - # Create a list with the src_dir stripped to use for outputs. 
- dest_files = files.replace(src_dir, '').splitlines() - src_files = files.splitlines() + if src_dir != None: + src_dir = _norm_path(src_dir) + dest_dir = _norm_path(dest_dir) + files = _read_dir(repository_ctx, src_dir) + # Create a list with the src_dir stripped to use for outputs. + dest_files = files.replace(src_dir, '').splitlines() + src_files = files.splitlines() command = [] outs = [] for i in range(len(dest_files)): @@ -151,12 +155,27 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name): # On Windows, symlink is not supported, so we just copy all the files. cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s' command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest)) - outs.append(' "' + dest_dir + dest_files[i] + '",') + outs.append(' "' + dest_dir + dest_files[i] + '",') genrule = _genrule(src_dir, genrule_name, " && ".join(command), "\n".join(outs)) return genrule +def _get_python_bin(repository_ctx): + """Gets the python bin path.""" + python_bin = _get_env_var(repository_ctx, _PYTHON_BIN_PATH, + None, False) + if python_bin != None: + return python_bin + python_bin_path = repository_ctx.which("python") + if python_bin_path != None: + return str(python_bin_path) + path = _get_env_var(repository_ctx, "PATH") + _python_configure_fail("Cannot find python in PATH, please make sure " + + "python is installed and add its directory in PATH, or set the " + + "environment variable PYTHON_BIN_PATH.\nPATH=%s" % (path)) + + def _get_python_lib(repository_ctx, python_bin): """Gets the python lib path.""" print_lib = ("< { - let assert = chai.assert; +// DO NOT EDIT: automatically generated file +#ifndef CUDA_CUDA_CONFIG_H_ +#define CUDA_CUDA_CONFIG_H_ - test('dagre exists', () => { assert.isTrue(dagre != null); }); +#define TF_CUDA_CAPABILITIES CudaVersion("3.0") - // TODO(bp): write tests. 
+#define TF_CUDA_VERSION "8.0" +#define TF_CUDNN_VERSION "5" -}); +#define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-8.0" + +#endif // CUDA_CUDA_CONFIG_H_ diff --git a/third_party/typings.bzl b/third_party/typings.bzl deleted file mode 100644 index d0c9eddbb3f52803310caed8775840b5af8fbbfa..0000000000000000000000000000000000000000 --- a/third_party/typings.bzl +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# TensorBoard typing dependencies - -load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") - -def tensorboard_typings_workspace(): - filegroup_external( - name = "org_definitelytyped", - licenses = ["notice"], # MIT - sha256_urls = { - "b7da645f6e5555feb7aeede73775da0023ce2257df9c8e76c9159266035a9c0d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/chai/chai.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/chai/chai.d.ts", - ], - "177293828c7a206bf2a7f725753d51396d38668311aa37c96445f91bbf8128a7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts", # v3 - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts", # v3 - ], - "e4cd3d5de0eb3bc7b1063b50d336764a0ac82a658b39b5cf90511f489ffdee60": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/efd40e67ff323f7147651bdbef03c03ead7b1675/lodash/lodash.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/efd40e67ff323f7147651bdbef03c03ead7b1675/lodash/lodash.d.ts", - ], - "695a03dd2ccb238161d97160b239ab841562710e5c4e42886aefd4ace2ce152e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/mocha/mocha.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/mocha/mocha.d.ts", - ], - "513ccd9ee1c708881120eeacd56788fc3b3da8e5c6172b20324cebbe858803fe": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/708609e0764daeb5eb64104af7aca50c520c4e6e/sinon/sinon.d.ts", - 
"https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/708609e0764daeb5eb64104af7aca50c520c4e6e/sinon/sinon.d.ts", - ], - "44eba36339bd1c0792072b7b204ee926fe5ffe1e9e2da916e67ac55548e3668a": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/a872802c0c84ba98ff207d5e673a1fa867c67fd6/polymer/polymer.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/a872802c0c84ba98ff207d5e673a1fa867c67fd6/polymer/polymer.d.ts", - ], - "9453c3e6bae824e90758c3b38975c1ed77e6abd79bf513bcb08368fcdb14898e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/f5407eba29c04fb8387c86df27512bd055b195d2/threejs/three.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/f5407eba29c04fb8387c86df27512bd055b195d2/threejs/three.d.ts", - ], - "691756a6eb455f340c9e834de0d49fff269e7b8c1799c2454465dcd6a4435b80": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/46719185c564694c5583c4b7ad94dbb786ecad46/webcomponents.js/webcomponents.js.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/46719185c564694c5583c4b7ad94dbb786ecad46/webcomponents.js/webcomponents.js.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_array", - licenses = ["notice"], # MIT - sha256_urls = { - "61e7abb7b1f01fbcb0cab8cf39003392f422566209edd681fbd070eaa84ca000": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-array/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-array/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_axis", - licenses = ["notice"], # MIT - sha256_urls = { - "95f75c8dcc89850b2e72581d96a7b5f46ea4ac852f828893f141f14a597421f9": [ - 
"http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-axis/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-axis/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_brush", - licenses = ["notice"], # MIT - sha256_urls = { - "a2738e693ce8a8640c2d29001e77582c9c361fd23bda44db471629866b60ada7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-brush/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-brush/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_chord", - licenses = ["notice"], # MIT - sha256_urls = { - "c54d24756eb6d744b31e538ad9bab3a75f6d54e2288b29cc72338d4a057d3e83": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-chord/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-chord/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_collection", - licenses = ["notice"], # MIT - sha256_urls = { - "f987667167b1d2970911247e325eb1c37ca0823646f81ccec837ae59039822f7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-collection/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-collection/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_color", - licenses = ["notice"], # MIT - sha256_urls = { - 
"9580c81f38ddcce7be0ac9bd3d0d083adebc34e17441709f90b9e4dcd1c19a56": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-color/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-color/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_dispatch", - licenses = ["notice"], # MIT - sha256_urls = { - "169f80b4cceca8e2e9ed384d81a5db0624cc01a26451dfb5a7e0cec6ea9cfb06": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dispatch/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dispatch/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_drag", - licenses = ["notice"], # MIT - sha256_urls = { - "08d35d139dde58c2722be98d718d01204fd6167d310f09b379e832f3c741489d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-drag/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-drag/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_dsv", - licenses = ["notice"], # MIT - sha256_urls = { - "62594d00cf9e4bb895339c8e56f64330e202a5eb2a0fa580a1f6e6336f2c93ce": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dsv/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dsv/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_ease", - licenses = ["notice"], # MIT - sha256_urls = { - 
"d1cf8f99b7bf758c2ba3c0a4ce553e151d4d9b4cf45a6e8bd0edec7ce90f725b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-ease/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-ease/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_force", - licenses = ["notice"], # MIT - sha256_urls = { - "288421e2008668d2076a4684657dd3d29b992832ef02c552981eb94a91042553": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-force/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-force/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_format", - licenses = ["notice"], # MIT - sha256_urls = { - "b42cb17e580c1fd0b64d478f7bd80ca806efaefda24426a833cf1f30a7275bca": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-format/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-format/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_hierarchy", - licenses = ["notice"], # MIT - sha256_urls = { - "a5683f5835d8716c6b89c075235078438cfab5897023ed720bfa492e244e969e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-hierarchy/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-hierarchy/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_interpolate", - licenses = ["notice"], # MIT 
- sha256_urls = { - "590a71b741323ac3139b333ec8b743e24717fdd5b32bcff48ee521162a9dfe1c": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-interpolate/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-interpolate/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_path", - licenses = ["notice"], # MIT - sha256_urls = { - "96f35ba041bcaa265e2b373ee675177410d44d31c980e4f7fbeefd4bcba15b00": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-path/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-path/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_polygon", - licenses = ["notice"], # MIT - sha256_urls = { - "ce453451e8105cac6a4f4a4263ca2142ebb4bf442e342f470a81da691f220fcb": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-polygon/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-polygon/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_quadtree", - licenses = ["notice"], # MIT - sha256_urls = { - "238e278f1be5d6985a19800800cffee80f81199f71d848e3bbc288d1791a6f90": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-quadtree/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-quadtree/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_queue", - 
licenses = ["notice"], # MIT - sha256_urls = { - "e6ae19aad83495475653578de64fb9d6bf9764eda6c84d70f7935ec84bcc482e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-queue/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-queue/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_random", - licenses = ["notice"], # MIT - sha256_urls = { - "d31b92ed86c23ec0a4776f99fa81ff033c95b96c8304d8aa9baf3b94af779aa8": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-random/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-random/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_request", - licenses = ["notice"], # MIT - sha256_urls = { - "44bb7b07d977028e6567540a3303b06fc9b33fb0960bc75c520e0733c840d89f": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-request/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-request/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_scale", - licenses = ["notice"], # MIT - sha256_urls = { - "02ce7c644ba34bd1abb84da2e832f248b048b6a23812be4365bd837f186c9f1f": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-scale/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-scale/index.d.ts", - ], - }, - ) - - filegroup_external( - name = 
"org_definitelytyped_types_d3_selection", - licenses = ["notice"], # MIT - sha256_urls = { - "699043ddb28dfa5e46d87bc6a24cfc6d604237f298259d3fb3c7066e05e8c86e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-selection/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-selection/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_shape", - licenses = ["notice"], # MIT - sha256_urls = { - "62668a7aaaf6232762b544f9f89c0f557ca7cfb0cd343a358dda7ecbe26f5739": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-shape/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-shape/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_time", - licenses = ["notice"], # MIT - sha256_urls = { - "0502490ce682fd9265fb1d5d693ce6cd82e3b05e5f5ee3433731266ecb03d5fc": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-time/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-time/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_timer", - licenses = ["notice"], # MIT - sha256_urls = { - "6f191f9aea704aa64b1defa40dfdff1447a6e6bb815feff1660f894500a9c94d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-timer/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-timer/index.d.ts", - ], - }, - ) - - filegroup_external( 
- name = "org_definitelytyped_types_d3_transition", - licenses = ["notice"], # MIT - sha256_urls = { - "a0a7c0c9bfb5c7d6d9d22a8d16b4484b66d13f2ed226954037546cb3da4098ba": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-transition/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-transition/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_voronoi", - licenses = ["notice"], # MIT - sha256_urls = { - "c6bd5f229f915151d0ef678fe50b1aa6a62334ea0a8c6fc0effbac9f7032efc7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-voronoi/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-voronoi/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_zoom", - licenses = ["notice"], # MIT - sha256_urls = { - "a25dc17fbd304cf7a0e5e7bbb8339c930d464eb40c4d6e5f839ce9c0191f4110": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-zoom/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-zoom/index.d.ts", - ], - }, - ) diff --git a/third_party/werkzeug.BUILD b/third_party/werkzeug.BUILD deleted file mode 100644 index 72a1402030d150c21b5d43261a4d5e2c0f1bce91..0000000000000000000000000000000000000000 --- a/third_party/werkzeug.BUILD +++ /dev/null @@ -1,14 +0,0 @@ -# Description: -# Werkzeug provides utilities for making WSGI applications - -licenses(["notice"]) # BSD 3-Clause - -exports_files(["LICENSE"]) - -# Note: this library includes test code. Consider creating a testonly target. 
-py_library( - name = "werkzeug", - srcs = glob(["werkzeug/*.py"]), - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], -) diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD index 279e6395b03a3c86b4b3fe25958ebafa4cb75062..9e9817b860aac700a57175c1fa6d7730b8d4e5dd 100644 --- a/third_party/zlib.BUILD +++ b/third_party/zlib.BUILD @@ -2,6 +2,18 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # BSD/MIT-like license (for zlib) +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "windows_msvc", + values = {"cpu": "x64_windows_msvc"}, + visibility = ["//visibility:public"], +) + cc_library( name = "zlib", srcs = [ @@ -32,9 +44,13 @@ cc_library( "zutil.h", ], hdrs = ["zlib.h"], - copts = [ - "-Wno-shift-negative-value", - "-Wno-implicit-function-declaration", - ], + copts = select({ + ":windows": [], + ":windows_msvc": [], + "//conditions:default": [ + "-Wno-shift-negative-value", + "-Wno-implicit-function-declaration", + ], + }), includes = ["."], )