diff --git a/.gitignore b/.gitignore index d11a504bdc56ee98b3d5a0c33f9f75d996e45567..be75938ec401b1d72fa54773c85191aaac7d7f35 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ node_modules /bazel-* /bazel_pip /tools/python_bin_path.sh -/tools/git/gen +/tensorflow/tools/git/gen /pip_test /_python_build *.pyc @@ -26,4 +26,11 @@ Podfile.lock /tensorflow/contrib/lite/gen/** /tensorflow/contrib/lite/examples/ios/simple/data/*.txt /tensorflow/contrib/lite/examples/ios/simple/data/*.tflite -xcuserdata/** \ No newline at end of file +xcuserdata/** + +# Android +.gradle +.idea +*.iml +local.properties +gradleBuild diff --git a/AUTHORS b/AUTHORS index a46ae7e616ab3a420d9fb2691ee8d8650032a39f..aa4be5169dcc68c579863e8ba6307cd00e9f9a68 100644 --- a/AUTHORS +++ b/AUTHORS @@ -7,4 +7,4 @@ # The email address is not required for organizations. Google Inc. -Yuan Tang terrytangyuan@gmail.com +Yuan Tang diff --git a/CODEOWNERS b/CODEOWNERS index 57a4df40e651f45dc03493af631d73332e46c182..007a304c3e706ce968576ec8979c08f1a3bcc552 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,53 +1,53 @@ # NOTE: Disabled temporarily because it's too noisy on pushes. # Where component owners are known, add them here. -#tensorflow/core/platform/windows/* @mrry -#tensorflow/java/* @asimshankar -#tensorflow/tensorboard/* @jart @dandelionmane -#tensorflow/tools/docs/* @markdaoust +# /tensorflow/core/platform/windows/ @mrry +# /tensorflow/java/ @asimshankar +# /tensorflow/tensorboard/ @jart @dandelionmane +# /tensorflow/tools/docs/ @markdaoust # contrib -# NEED OWNER: tensorflow/contrib/avro/* -#tensorflow/contrib/batching/* @alextp @chrisolston -#tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon -#tensorflow/contrib/boosted_trees/* @sshrdp @yk5 @nataliaponomareva -#tensorflow/contrib/cmake/* @mrry @benoitsteiner -#tensorflow/contrib/copy_graph/* @tucker @poxvoculi -#tensorflow/contrib/crf/* @kentonl -#tensorflow/contrib/data/* @mrry -#tensorflow/contrib/distributions/* @jvdillon @langmore @rsepassi -#tensorflow/contrib/factorization/* @agarwal-ashish @xavigonzalvo -#tensorflow/contrib/ffmpeg/* @fredbertsch -# NEED OWNER: tensorflow/contrib/framework/* -#tensorflow/contrib/graph_editor/* @purpledog -# NEED OWNER: tensorflow/contrib/grid_rnn/* -#tensorflow/contrib/hvx/* @satok16 -#tensorflow/contrib/integrate/* @shoyer -#tensorflow/contrib/kernel_methods/* @petrosmol -#tensorflow/contrib/ios_examples/* @petewarden -#tensorflow/contrib/labeled_tensor/* @shoyer -#tensorflow/contrib/layers/* @fchollet @martinwicke -#tensorflow/contrib/learn/* @martinwicke @ispirmustafa @alextp -#tensorflow/contrib/linalg/* @langmore -#tensorflow/contrib/linear_optimizer/* @petrosmol @andreasst @katsiapis -#tensorflow/contrib/lookup/* @ysuematsu @andreasst -#tensorflow/contrib/losses/* @alextp @ispirmustafa -#tensorflow/contrib/makefile/* @petewarden @satok16 @wolffg -#tensorflow/contrib/metrics/* @alextp @honkentuber @ispirmustafa -#tensorflow/contrib/nccl/* @cwhipkey @zheng-xq -#tensorflow/contrib/opt/* @strategist333 -#tensorflow/contrib/pi_examples/* @maciekcc -#tensorflow/contrib/quantization/* @petewarden @cwhipkey @keveman -#tensorflow/contrib/rnn/* @ebrevdo -#tensorflow/contrib/saved_model/* @nfiedel @sukritiramesh -#tensorflow/contrib/seq2seq/* @lukaszkaiser -#tensorflow/contrib/session_bundle/* @nfiedel @sukritiramesh -#tensorflow/contrib/slim/* @sguada @thenbasilmanran -#tensorflow/contrib/stateless/* @girving -#tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst -#tensorflow/contrib/testing/* @dandelionmane -#tensorflow/contrib/timeseries/* @allenlavoie -#tensorflow/contrib/tpu/* @frankchn @saeta @jhseu -#tensorflow/contrib/training/* @joel-shor @ebrevdo -#tensorflow/contrib/util/* @sherrym +# NEED OWNER: /tensorflow/contrib/avro/ +# /tensorflow/contrib/batching/ @alextp @chrisolston +# /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon +# /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva +# /tensorflow/contrib/cmake/ @mrry @benoitsteiner +# /tensorflow/contrib/copy_graph/ @tucker @poxvoculi +# /tensorflow/contrib/crf/ @kentonl +# /tensorflow/contrib/data/ @mrry +# /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi +# /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo +# /tensorflow/contrib/ffmpeg/ @fredbertsch +# NEED OWNER: /tensorflow/contrib/framework/ +# /tensorflow/contrib/graph_editor/ @purpledog +# NEED OWNER: /tensorflow/contrib/grid_rnn/ +# /tensorflow/contrib/hvx/ @satok16 +# /tensorflow/contrib/integrate/ @shoyer +# /tensorflow/contrib/kernel_methods/ @petrosmol +# /tensorflow/contrib/ios_examples/ @petewarden +# /tensorflow/contrib/labeled_tensor/ @shoyer +# /tensorflow/contrib/layers/ @fchollet @martinwicke +# /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp +# /tensorflow/contrib/linalg/ @langmore +# /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis +# /tensorflow/contrib/lookup/ @ysuematsu @andreasst +# /tensorflow/contrib/losses/ @alextp @ispirmustafa +# /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg +# /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa +# /tensorflow/contrib/nccl/ @cwhipkey @zheng-xq +# /tensorflow/contrib/opt/ @strategist333 +# /tensorflow/contrib/pi_examples/ @maciekcc +# /tensorflow/contrib/quantization/ @petewarden @cwhipkey @keveman +# /tensorflow/contrib/rnn/ @ebrevdo +# /tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh +# /tensorflow/contrib/seq2seq/ @lukaszkaiser +# /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh +# /tensorflow/contrib/slim/ @sguada @thenbasilmanran +# /tensorflow/contrib/stateless/ @girving +# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst +# /tensorflow/contrib/testing/ @dandelionmane +# /tensorflow/contrib/timeseries/ @allenlavoie +# /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu +# /tensorflow/contrib/training/ @joel-shor @ebrevdo +# /tensorflow/contrib/util/ @sherrym diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 43abdaafbf45379430920cd027b26299cd62553b..dc96bc2e3d3960827efd109551f8eaa78a6cfb48 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,6 +20,9 @@ Follow either of the two links above to access the appropriate CLA and instructi If you have improvements to TensorFlow, send us your pull requests! For those just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/). +TensorFlow team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, we will merge the pull requests. +For some pull requests, we will apply the patch for each pull request to our internal version control system first, and export the change out as a new commit later, at which point the original pull request will be closed. The commits in the pull request will be squashed into a single commit with the pull request creator as the author. These pull requests will be labeled as pending merge internally. + If you want to contribute but you're not sure where to start, take a look at the [issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome). These are issues that we believe are particularly well suited for outside @@ -114,6 +117,7 @@ pylint --rcfile=/tmp/pylintrc myfile.py * [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) * [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) * [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) +* [Google Objective-C Style Guide](http://google.github.io/styleguide/objcguide.html) #### Running sanity check diff --git a/RELEASE.md b/RELEASE.md index d8db1f72004b5d944e3035a0f33dfc34a674b7ee..e04bd3fc505d51ade9e9fa12c822cb695e90b4f3 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -494,7 +494,7 @@ answered questions, and were part of inspiring discussions. This release contains contributions from many people at Google, as well as: A. Besir Kurtulmus, Adal Chiriliuc, @akash, Alec-Desouza, Alex Rothberg, Alex -Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton +Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton Loss, @Aravind, @Arie, Ashutosh Das, AuréLien Geron, Bairen Yi, @bakunyo, Ben Visser, Brady Zhou, Calpa Liu, Changming Sun, Chih Cheng Liang, Christopher Berner, Clark Zinzow, @Conchylicultor, Dan Ellis, Dan J, Dan Jarvis, Daniel diff --git a/configure.py b/configure.py index 26da09bd947a0aa3887630d8f2205ec058886b1a..fa2ed2450d4801353622b51da1fdb822778c0811 100644 --- a/configure.py +++ b/configure.py @@ -34,8 +34,10 @@ except ImportError: _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.tf_configure.bazelrc') -_DEFAULT_CUDA_VERSION = '8.0' -_DEFAULT_CUDNN_VERSION = '6' +_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'WORKSPACE') +_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' @@ -44,6 +46,13 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] + +_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 + + +class UserInputError(Exception): + pass def is_windows(): @@ -158,7 +167,7 @@ def get_python_path(environ_cp, python_bin_path): try: library_paths = run_shell( [python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))']).split("\n") + 'import site; print("\\n".join(site.getsitepackages()))']).split('\n') except subprocess.CalledProcessError: library_paths = [run_shell( [python_bin_path, '-c', @@ -557,6 +566,218 @@ def set_clang_cuda_compiler_path(environ_cp): clang_cuda_compiler_path) +def prompt_loop_or_load_from_env( + environ_cp, + var_name, + var_default, + ask_for_var, + check_success, + error_msg, + suppress_default_error=False, + n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS +): + """Loop over user prompts for an ENV param until receiving a valid response. + + For the env param var_name, read from the environment or verify user input + until receiving valid input. When done, set var_name in the environ_cp to its + new value. + + Args: + environ_cp: (Dict) copy of the os.environ. + var_name: (String) string for name of environment variable, e.g. "TF_MYVAR". + var_default: (String) default value string. + ask_for_var: (String) string for how to ask for user input. + check_success: (Function) function that takes one argument and returns a + boolean. Should return True if the value provided is considered valid. May + contain a complex error message if error_msg does not provide enough + information. In that case, set suppress_default_error to True. + error_msg: (String) String with one and only one '%s'. Formatted with each + invalid response upon check_success(input) failure. + suppress_default_error: (Bool) Suppress the above error message in favor of + one from the check_success function. + n_ask_attempts: (Integer) Number of times to query for valid input before + raising an error and quitting. + + Returns: + [String] The value of var_name after querying for input. + + Raises: + UserInputError: if a query has been attempted n_ask_attempts times without + success, assume that the user has made a scripting error, and will continue + to provide invalid input. Raise the error to avoid infinitely looping. + """ + default = environ_cp.get(var_name) or var_default + full_query = '%s [Default is %s]: ' % ( + ask_for_var, + default, + ) + + for _ in range(n_ask_attempts): + val = get_from_env_or_user_or_default(environ_cp, + var_name, + full_query, + default) + if check_success(val): + break + if not suppress_default_error: + print(error_msg % val) + environ_cp[var_name] = '' + else: + raise UserInputError('Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' % + (var_name, n_ask_attempts)) + + environ_cp[var_name] = val + return val + + +def create_android_ndk_rule(environ_cp): + """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % + environ_cp['APPDATA']) + elif is_macos(): + default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + + def valid_ndk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'source.properties'))) + + android_ndk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_NDK_HOME', + var_default=default_ndk_path, + ask_for_var='Please specify the home path of the Android NDK to use.', + check_success=valid_ndk_path, + error_msg=('The path %s or its child file "source.properties" ' + 'does not exist.') + ) + + write_android_ndk_workspace_rule(android_ndk_home_path) + + +def create_android_sdk_rule(environ_cp): + """Set Android variables and write Android SDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA']) + elif is_macos(): + default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME'] + + def valid_sdk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'platforms')) and + os.path.exists(os.path.join(path, 'build-tools'))) + + android_sdk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_SDK_HOME', + var_default=default_sdk_path, + ask_for_var='Please specify the home path of the Android SDK to use.', + check_success=valid_sdk_path, + error_msg=('Either %s does not exist, or it does not contain the ' + 'subdirectories "platforms" and "build-tools".')) + + platforms = os.path.join(android_sdk_home_path, 'platforms') + api_levels = sorted(os.listdir(platforms)) + api_levels = [x.replace('android-', '') for x in api_levels] + + def valid_api_level(api_level): + return os.path.exists(os.path.join(android_sdk_home_path, + 'platforms', + 'android-' + api_level)) + + android_api_level = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_API_LEVEL', + var_default=api_levels[-1], + ask_for_var=('Please specify the Android SDK API level to use. ' + '[Available levels: %s]') % api_levels, + check_success=valid_api_level, + error_msg='Android-%s is not present in the SDK path.') + + build_tools = os.path.join(android_sdk_home_path, 'build-tools') + versions = sorted(os.listdir(build_tools)) + + def valid_build_tools(version): + return os.path.exists(os.path.join(android_sdk_home_path, + 'build-tools', + version)) + + android_build_tools_version = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_BUILD_TOOLS_VERSION', + var_default=versions[-1], + ask_for_var=('Please specify an Android build tools version to use. ' + '[Available versions: %s]') % versions, + check_success=valid_build_tools, + error_msg=('The selected SDK does not have build-tools version %s ' + 'available.')) + + write_android_sdk_workspace_rule(android_sdk_home_path, + android_build_tools_version, + android_api_level) + + +def write_android_sdk_workspace_rule(android_sdk_home_path, + android_build_tools_version, + android_api_level): + print('Writing android_sdk_workspace rule.\n') + with open(_TF_WORKSPACE, 'a') as f: + f.write(""" +android_sdk_repository( + name="androidsdk", + api_level=%s, + path="%s", + build_tools_version="%s")\n +""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) + + +def write_android_ndk_workspace_rule(android_ndk_home_path): + print('Writing android_ndk_workspace rule.') + ndk_api_level = check_ndk_level(android_ndk_home_path) + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + with open(_TF_WORKSPACE, 'a') as f: + f.write(""" +android_ndk_repository( + name="androidndk", + path="%s", + api_level=%s)\n +""" % (android_ndk_home_path, ndk_api_level)) + + +def check_ndk_level(android_ndk_home_path): + """Check the revision number of an Android NDK path.""" + properties_path = '%s/source.properties' % android_ndk_home_path + if is_windows() or is_cygwin(): + properties_path = cygpath(properties_path) + with open(properties_path, 'r') as f: + filedata = f.read() + + revision = re.search(r'Pkg.Revision = (\d+)', filedata) + if revision: + return revision.group(1) + return None + + +def workspace_has_any_android_rule(): + """Check the WORKSPACE for existing android_*_repository rules.""" + with open(_TF_WORKSPACE, 'r') as f: + workspace = f.read() + has_any_rule = re.search(r'^android_[ns]dk_repository', + workspace, + re.MULTILINE) + return has_any_rule + + def set_gcc_host_compiler_path(environ_cp): """Set GCC_HOST_COMPILER_PATH.""" default_gcc_host_compiler_path = which('gcc') or '' @@ -566,23 +787,16 @@ def set_gcc_host_compiler_path(environ_cp): # os.readlink is only available in linux default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) - ask_gcc_path = ( - 'Please specify which gcc should be used by nvcc as the ' - 'host compiler. [Default is %s]: ') % default_gcc_host_compiler_path - while True: - gcc_host_compiler_path = get_from_env_or_user_or_default( - environ_cp, 'GCC_HOST_COMPILER_PATH', ask_gcc_path, - default_gcc_host_compiler_path) - - if os.path.exists(gcc_host_compiler_path): - break - - # Reset and retry - print('Invalid gcc path. %s cannot be found' % gcc_host_compiler_path) - environ_cp['GCC_HOST_COMPILER_PATH'] = '' + gcc_host_compiler_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='GCC_HOST_COMPILER_PATH', + var_default=default_gcc_host_compiler_path, + ask_for_var= + 'Please specify which gcc should be used by nvcc as the host compiler.', + check_success=os.path.exists, + error_msg='Invalid gcc path. %s cannot be found.', + ) - # Set GCC_HOST_COMPILER_PATH - environ_cp['GCC_HOST_COMPILER_PATH'] = gcc_host_compiler_path write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) @@ -592,7 +806,7 @@ def set_tf_cuda_version(environ_cp): 'Please specify the CUDA SDK version you want to use, ' 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - while True: + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): # Configure the Cuda SDK version to use. tf_cuda_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) @@ -630,6 +844,11 @@ def set_tf_cuda_version(environ_cp): environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' + else: + raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) + # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path) @@ -643,7 +862,7 @@ def set_tf_cudnn_version(environ_cp): 'Please specify the cuDNN version you want to use. ' '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION - while True: + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_cudnn_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, _DEFAULT_CUDNN_VERSION) @@ -702,6 +921,10 @@ def set_tf_cudnn_version(environ_cp): print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)) environ_cp['TF_CUDNN_VERSION'] = '' + else: + raise UserInputError('Invalid TF_CUDNN setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) # Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path @@ -810,76 +1033,66 @@ def set_other_cuda_vars(environ_cp): def set_host_cxx_compiler(environ_cp): """Set HOST_CXX_COMPILER.""" default_cxx_host_compiler = which('g++') or '' - ask_cxx_host_compiler = ( - 'Please specify which C++ compiler should be used as' - ' the host C++ compiler. [Default is %s]: ') % default_cxx_host_compiler - - while True: - host_cxx_compiler = get_from_env_or_user_or_default( - environ_cp, 'HOST_CXX_COMPILER', ask_cxx_host_compiler, - default_cxx_host_compiler) - if os.path.exists(host_cxx_compiler): - break - # Reset and retry - print('Invalid C++ compiler path. %s cannot be found' % host_cxx_compiler) - environ_cp['HOST_CXX_COMPILER'] = '' + host_cxx_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_CXX_COMPILER', + var_default=default_cxx_host_compiler, + ask_for_var=('Please specify which C++ compiler should be used as the ' + 'host C++ compiler.'), + check_success=os.path.exists, + error_msg='Invalid C++ compiler path. %s cannot be found.', + ) - # Set HOST_CXX_COMPILER - environ_cp['HOST_CXX_COMPILER'] = host_cxx_compiler write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler) def set_host_c_compiler(environ_cp): """Set HOST_C_COMPILER.""" default_c_host_compiler = which('gcc') or '' - ask_c_host_compiler = ( - 'Please specify which C compiler should be used as the' - ' host C compiler. [Default is %s]: ') % default_c_host_compiler - - while True: - host_c_compiler = get_from_env_or_user_or_default( - environ_cp, 'HOST_C_COMPILER', ask_c_host_compiler, - default_c_host_compiler) - if os.path.exists(host_c_compiler): - break - # Reset and retry - print('Invalid C compiler path. %s cannot be found' % host_c_compiler) - environ_cp['HOST_C_COMPILER'] = '' + host_c_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_C_COMPILER', + var_default=default_c_host_compiler, + ask_for_var=('Please specify which C compiler should be used as the host' + 'C compiler.'), + check_success=os.path.exists, + error_msg='Invalid C compiler path. %s cannot be found.', + ) - # Set HOST_C_COMPILER - environ_cp['HOST_C_COMPILER'] = host_c_compiler write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) def set_computecpp_toolkit_path(environ_cp): """Set COMPUTECPP_TOOLKIT_PATH.""" - ask_computecpp_toolkit_path = ('Please specify the location where ComputeCpp ' - 'for SYCL %s is installed. [Default is %s]: ' - ) % (_TF_OPENCL_VERSION, - _DEFAULT_COMPUTECPP_TOOLKIT_PATH) - while True: - computecpp_toolkit_path = get_from_env_or_user_or_default( - environ_cp, 'COMPUTECPP_TOOLKIT_PATH', ask_computecpp_toolkit_path, - _DEFAULT_COMPUTECPP_TOOLKIT_PATH) + def toolkit_exists(toolkit_path): + """Check if a computecpp toolkit path is valid.""" if is_linux(): sycl_rt_lib_path = 'lib/libComputeCpp.so' else: sycl_rt_lib_path = '' - sycl_rt_lib_path_full = os.path.join(computecpp_toolkit_path, + sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) - if os.path.exists(sycl_rt_lib_path_full): - break + exists = os.path.exists(sycl_rt_lib_path_full) + if not exists: + print('Invalid SYCL %s library path. %s cannot be found' % + (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) + return exists - print('Invalid SYCL %s library path. %s cannot be found' % - (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) - environ_cp['COMPUTECPP_TOOLKIT_PATH'] = '' + computecpp_toolkit_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='COMPUTECPP_TOOLKIT_PATH', + var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH, + ask_for_var=( + 'Please specify the location where ComputeCpp for SYCL %s is ' + 'installed.' % _TF_OPENCL_VERSION), + check_success=toolkit_exists, + error_msg='Invalid SYCL compiler path. %s cannot be found.', + suppress_default_error=True) - # Set COMPUTECPP_TOOLKIT_PATH - environ_cp['COMPUTECPP_TOOLKIT_PATH'] = computecpp_toolkit_path write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH', computecpp_toolkit_path) @@ -905,28 +1118,30 @@ def set_trisycl_include_dir(environ_cp): write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) + def set_mpi_home(environ_cp): """Set MPI_HOME.""" + default_mpi_home = which('mpirun') or which('mpiexec') or '' default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) - ask_mpi_home = ('Please specify the MPI toolkit folder. [Default is %s]: ' - ) % default_mpi_home - while True: - mpi_home = get_from_env_or_user_or_default(environ_cp, 'MPI_HOME', - ask_mpi_home, default_mpi_home) - - if os.path.exists(os.path.join(mpi_home, 'include')) and os.path.exists( - os.path.join(mpi_home, 'lib')): - break - - print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % - (os.path.join(mpi_home, 'include'), - os.path.exists(os.path.join(mpi_home, 'lib')))) - environ_cp['MPI_HOME'] = '' + def valid_mpi_path(mpi_home): + exists = (os.path.exists(os.path.join(mpi_home, 'include')) and + os.path.exists(os.path.join(mpi_home, 'lib'))) + if not exists: + print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % + (os.path.join(mpi_home, 'include'), + os.path.exists(os.path.join(mpi_home, 'lib')))) + return exists - # Set MPI_HOME - environ_cp['MPI_HOME'] = str(mpi_home) + _ = prompt_loop_or_load_from_env( + environ_cp, + var_name='MPI_HOME', + var_default=default_mpi_home, + ask_for_var='Please specify the MPI toolkit folder.', + check_success=valid_mpi_path, + error_msg='', + suppress_default_error=True) def set_other_mpi_vars(environ_cp): @@ -969,7 +1184,7 @@ def set_mkl(): 'support.\nPlease note that MKL on MacOS or windows is still not ' 'supported.\nIf you would like to use a local MKL instead of ' 'downloading, please set the environment variable \"TF_MKL_ROOT\" every ' - 'time before build.') + 'time before build.\n') def set_monolithic(): @@ -1001,6 +1216,15 @@ def create_android_bazelrc_configs(): def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_windows_build_flags(): + if is_windows(): + # The non-monolithic build is not supported yet + write_to_bazelrc('build --config monolithic') + # Suppress warning messages + write_to_bazelrc('build --copt=-w --host_copt=-w') + # Output more verbose information when something goes wrong + write_to_bazelrc('build --verbose_failures') + def main(): # Make a copy of os.environ to be clear when functions and getting and setting @@ -1079,7 +1303,25 @@ def main(): set_cc_opt_flags(environ_cp) set_mkl() set_monolithic() + set_windows_build_flags() create_android_bazelrc_configs() + if workspace_has_any_android_rule(): + print('The WORKSPACE file has at least one of ["android_sdk_repository", ' + '"android_ndk_repository"] already set. Will not ask to help ' + 'configure the WORKSPACE. Please delete the existing rules to ' + 'activate the helper.\n') + else: + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) + + if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index c0a47cf6b4ae2dcfab15472758023480fb48482d..259dde384c794e980be7e958b2448dc92b9be441 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -364,14 +364,6 @@ config_setting( visibility = ["//visibility:public"], ) -# Make a dummy rule that we can change "default" in select statements to. -# to disable dependencies in copybara. -config_setting( - name = "dummy_disabled_internal", - values = {"define": "with_dummy_disabled_internal=true"}, - visibility = ["//visibility:public"], -) - package_group( name = "internal", packages = [ @@ -427,6 +419,7 @@ filegroup( "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", + "//tensorflow/compiler/xla/python:all_files", "//tensorflow/compiler/xla/service:all_files", "//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files", @@ -462,6 +455,7 @@ filegroup( "//tensorflow/contrib/data/python/ops:all_files", "//tensorflow/contrib/decision_trees/proto:all_files", "//tensorflow/contrib/distributions:all_files", + "//tensorflow/contrib/eager/proto:all_files", "//tensorflow/contrib/eager/python:all_files", "//tensorflow/contrib/estimator:all_files", "//tensorflow/contrib/factorization:all_files", @@ -562,6 +556,7 @@ filegroup( "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files", "//tensorflow/contrib/tpu:all_files", "//tensorflow/contrib/tpu/profiler:all_files", + "//tensorflow/contrib/tpu/proto:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", @@ -576,6 +571,8 @@ filegroup( "//tensorflow/core/grappler/optimizers:all_files", "//tensorflow/core/grappler/utils:all_files", "//tensorflow/core/kernels:all_files", + "//tensorflow/core/kernels/data:all_files", + "//tensorflow/core/kernels/data/sql:all_files", "//tensorflow/core/kernels/fuzzing:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/kernels/neon:all_files", @@ -609,6 +606,7 @@ filegroup( "//tensorflow/java/src/main/native:all_files", "//tensorflow/python:all_files", "//tensorflow/python/data:all_files", + "//tensorflow/python/data/kernel_tests:all_files", "//tensorflow/python/data/ops:all_files", "//tensorflow/python/data/util:all_files", "//tensorflow/python/debug:all_files", @@ -645,6 +643,7 @@ filegroup( "//tensorflow/tools/test:all_files", "//tensorflow/user_ops:all_files", "//third_party/hadoop:all_files", + "//third_party/mpi:all_files", "//third_party/sycl:all_files", "//third_party/sycl/sycl:all_files", ], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index dd638de3c6933fde6214993ae7b15b40b1acf65b..9b5704702841081d7dde78ac019305140066f688 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -383,12 +383,11 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, // be less than the total node count. Status ValidateNoCycles(const Graph& g) { // TODO(nolivia): check this on a subset of the graph instead of all of it. - int total_num_nodes = g.num_node_ids(); // A node is ready when all of its inputs have been visited. std::vector ready; - std::vector pending_count(total_num_nodes, 0); + std::vector pending_count(g.num_node_ids(), 0); - for (int i = 0; i < total_num_nodes; ++i) { + for (int i = 0; i < g.num_node_ids(); ++i) { const Node* n = g.FindNodeId(i); if (n == nullptr) continue; pending_count[i] = n->in_edges().size(); @@ -421,7 +420,7 @@ Status ValidateNoCycles(const Graph& g) { } } - if (processed < total_num_nodes) { + if (processed < g.num_nodes()) { std::vector nodes_in_cycle; for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; ++i) { @@ -430,7 +429,7 @@ Status ValidateNoCycles(const Graph& g) { } } return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", total_num_nodes - processed, + "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); } return Status::OK(); @@ -580,6 +579,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (string #", i, " of ", srcarray.size(), "): ", status->status.error_message()); + delete[] base; return nullptr; } dst += consumed; @@ -589,6 +589,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (decoded ", (dst - base), " bytes, but the tensor is encoded in ", size, " bytes"); + delete[] base; return nullptr; } @@ -625,6 +626,23 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, return Status::OK(); } +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + // If any session has already run this node_id, mark this session as + // unrunnable. + for (auto it : graph->sessions) { + if (it.first->last_num_graph_nodes > op.node.id()) { + it.second = FailedPrecondition( + "Operation '", op.node.DebugString(), "' was changed by ", + mutation_type, + " after it was run by a session. Nodes can be mutated " + "only before they are executed by a session. Either don't modify " + "nodes after running them or create a new session."); + } + } +} + // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -939,13 +957,17 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, return; } - std::vector dim_vec; - dim_vec.reserve(num_dims); - for (int i = 0; i < num_dims; ++i) { - dim_vec.push_back(ic->MakeDim(dims[i])); + tensorflow::shape_inference::ShapeHandle new_shape; + if (num_dims != -1) { + std::vector dim_vec; + dim_vec.reserve(num_dims); + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + new_shape = ic->MakeShape(dim_vec); + } else { + new_shape = ic->UnknownShape(); } - - tensorflow::shape_inference::ShapeHandle new_shape = ic->MakeShape(dim_vec); status->status = graph->refiner.SetShape(node, output.index, new_shape); } @@ -1741,7 +1763,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, TF_Graph::TF_Graph() : graph(tensorflow::OpRegistry::Global()), refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), delete_requested(false), parent(nullptr), parent_inputs(nullptr) {} @@ -1751,7 +1772,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { g->mu.lock(); g->delete_requested = true; - const bool del = g->num_sessions == 0; + const bool del = g->sessions.empty(); g->mu.unlock(); if (del) delete g; } @@ -1831,6 +1852,16 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, opts->opts.prefix = prefix; } +void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_names) { + opts->opts.uniquify_names = uniquify_names; +} + +void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_prefix) { + opts->opts.uniquify_prefix = uniquify_prefix; +} + void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { @@ -1888,12 +1919,12 @@ void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, *opers = results->return_nodes.data(); } -void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes) { - *num_unused_input_mappings = results->unused_key_names.size(); - *src_names = results->unused_key_names.data(); - *src_indexes = results->unused_key_indexes.data(); + *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); + *src_names = results->missing_unused_key_names.data(); + *src_indexes = results->missing_unused_key_indexes.data(); } void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { @@ -1933,18 +1964,21 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); } - // Populate unused map keys - DCHECK(tf_results->unused_key_names.empty()); - DCHECK(tf_results->unused_key_indexes.empty()); - DCHECK(tf_results->unused_key_names_data.empty()); - tf_results->unused_key_names.resize(results.unused_input_map_keys.size()); - tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size()); - for (int i = 0; i < results.unused_input_map_keys.size(); ++i) { - TensorId id = results.unused_input_map_keys[i]; - tf_results->unused_key_names_data.push_back(id.first.ToString()); - tf_results->unused_key_names[i] = - tf_results->unused_key_names_data.back().c_str(); - tf_results->unused_key_indexes[i] = id.second; + // Populate missing unused map keys + DCHECK(tf_results->missing_unused_key_names.empty()); + DCHECK(tf_results->missing_unused_key_indexes.empty()); + DCHECK(tf_results->missing_unused_key_names_data.empty()); + + size_t size = results.missing_unused_input_map_keys.size(); + tf_results->missing_unused_key_names.resize(size); + tf_results->missing_unused_key_indexes.resize(size); + + for (int i = 0; i < size; ++i) { + TensorId id = results.missing_unused_input_map_keys[i]; + tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names[i] = + tf_results->missing_unused_key_names_data.back().c_str(); + tf_results->missing_unused_key_indexes[i] = id.second; } } @@ -2321,11 +2355,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, Session* session; status->status = NewSession(opt->options, &session); if (status->status.ok()) { + TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->num_sessions += 1; + graph->sessions[new_session] = Status::OK(); } - return new TF_Session(session, graph); + return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; @@ -2389,7 +2424,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->num_sessions += 1; + graph->sessions[session] = Status::OK(); session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2404,8 +2439,8 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) { TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); - graph->num_sessions -= 1; - const bool del = graph->delete_requested && graph->num_sessions == 0; + graph->sessions.erase(s); + const bool del = graph->delete_requested && graph->sessions.empty(); graph->mu.unlock(); if (del) delete graph; } @@ -2421,6 +2456,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { mutex_lock session_lock(session->mu); session->graph->mu.lock(); const Graph& graph = session->graph->graph; + + status->status = session->graph->sessions[session]; + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { status->status = tensorflow::ValidateNoCycles(session->graph->graph); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index bb569d67fcbcec29e9494236abd79b3e40db91cd..de9527f86d1f48846b160230c592a398e00e10c5 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -889,6 +889,20 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); +// Set whether to uniquify imported operation names. If true, imported operation +// names will be modified if their name already exists in the graph. If false, +// conflicting names will be treated as an error. Note that this option has no +// effect if a prefix is set, since the prefix will guarantee all names are +// unique. Defaults to false. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); + +// If true, the specified prefix will be modified if it already exists as an +// operation name or prefix in the graph. If false, a conflicting prefix will be +// treated as an error. This option has no effect if no prefix is specified. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); + // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. @@ -948,16 +962,16 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); // Fetches any input mappings requested via -// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any -// node in the imported graph def. The number of fetched mappings is returned in -// `num_unused_input_mappings`. The array of each mapping's source node name is -// returned in `src_names`, and the array of each mapping's source index is -// returned in `src_indexes`. +// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef +// and weren't used as input to any node in the imported graph def. The number +// of fetched mappings is returned in `num_missing_unused_input_mappings`. The +// array of each mapping's source node name is returned in `src_names`, and the +// array of each mapping's source index is returned in `src_indexes`. // // `*src_names`, `*src_indexes`, and the memory backing each string in // `src_names` are owned by and have the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes); // Deletes a results object returned by TF_GraphImportGraphDefWithResults(). diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index dcb818b88b6fca460852beb6e948d2eb6964f663..d60d1de315ed37a327bd036ddb914a3c32413f65 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -68,7 +68,7 @@ class NodeNameMapping { // This is a superset of values in name_mapping_. std::unordered_set used_names_; // Mapping from original node name from the graph to the normalized - // and uniqified version of it. + // and uniquified version of it. std::unordered_map name_mapping_; }; @@ -226,12 +226,17 @@ Status FillFunctionBody( } node_def->add_input(strings::StrCat("^", normalized)); } + + // A function is stateful if any of its nodes are stateful. + if (node->op_def().is_stateful()) { + fdef->mutable_signature()->set_is_stateful(true); + } } return Status::OK(); } // Graph to FunctionDef conversion. This code is closely modeled on the Python -// code in third_party/tensorflow/python/framework/function.py. +// code in tensorflow/python/framework/function.py. Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, bool append_hash_to_fn_name, const std::vector& body_nodes, diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index d5580b658992413ae6f9cb79ef88751ee28ce465..2e2293ca85175009d1bcc8db5c830789e4701c1d 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1482,6 +1482,51 @@ TEST_F(CApiFunctionTest, GetOpDef) { EXPECT_EQ(op_def.name(), func_name_); EXPECT_EQ(op_def.input_arg_size(), 1); EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_FALSE(op_def.is_stateful()); + + TF_DeleteBuffer(buffer); +} + +void DefineStatefulFunction(const char* name, TF_Function** func) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Tensor* tensor_shape = Int32Tensor({37, 1}); + TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape"); + TF_Operation* random = + RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get()); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{random, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1, + /*opers=*/nullptr, 0, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, "", s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); + TF_DeleteTensor(tensor_shape); +} + +TEST_F(CApiFunctionTest, StatefulOpDef) { + DefineStatefulFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 0); + EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_TRUE(op_def.is_stateful()); TF_DeleteBuffer(buffer); } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index bb04e01beec931a8ea66d0855eec9625d3a6a5ab..6df77a7f9baed999a2f2cb5e9404cb63451b6212 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -81,12 +81,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // TF_Graph may only / must be deleted when - // num_sessions == 0 && delete_requested == true - - // num_sessions incremented by TF_NewSession, and decremented by + // The keys of this map are all the active sessions using this graph. + // Each value is the current "runnability" status of the corresponding + // session. Under normal conditions all statuses are Status::OK(), but + // if some operation is mutated after it was run by a session (this + // is detected in RecordMutation function), that session is no longer + // safe to run. Its status will contain the error that will be returned + // to the user, should she try running this session. + // + // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. - int num_sessions GUARDED_BY(mu); + // TF_Graph may only / must be deleted when + // sessions.size() == 0 && delete_requested == true + tensorflow::gtl::FlatMap sessions + GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph // Used to link graphs contained in TF_WhileParams to the parent graph that @@ -135,11 +143,11 @@ struct TF_ImportGraphDefOptions { struct TF_ImportGraphDefResults { std::vector return_tensors; std::vector return_nodes; - std::vector unused_key_names; - std::vector unused_key_indexes; + std::vector missing_unused_key_names; + std::vector missing_unused_key_indexes; - // Backing memory for unused_key_names values. - std::list unused_key_names_data; + // Backing memory for missing_unused_key_names values. + std::list missing_unused_key_names_data; }; struct TF_DeviceList { @@ -167,6 +175,9 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index e0057eb51cd82e8d9ed5fcf56e296f9fb0c2fe40..4e89b4fc43973e4cc9a6c64f50e288d81bf22033 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -287,6 +287,13 @@ TEST(CAPI, SetShape) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); EXPECT_EQ(-1, num_dims); + // Set the shape to be unknown, expect no change. + TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(-1, num_dims); + // Set the shape to be 2 x Unknown int64_t dims[] = {2, -1}; TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); @@ -315,7 +322,17 @@ TEST(CAPI, SetShape) { EXPECT_EQ(dims[0], returned_dims[0]); EXPECT_EQ(dims[1], returned_dims[1]); - // Try to set 'unknown' on the shape and see that + // Try to set 'unknown' with unknown rank on the shape and see that + // it doesn't change. + TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, returned_dims[0]); + EXPECT_EQ(3, returned_dims[1]); + + // Try to set 'unknown' with same rank on the shape and see that // it doesn't change. dims[0] = -1; dims[1] = -1; @@ -756,7 +773,7 @@ TEST(CAPI, ImportGraphDef_WithReturnOutputs) { TF_DeleteStatus(s); } -TEST(CAPI, ImportGraphDef_UnusedInputMappings) { +TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -799,7 +816,7 @@ TEST(CAPI, ImportGraphDef_UnusedInputMappings) { int num_unused_input_mappings; const char** src_names; int* src_indexes; - TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResultsMissingUnusedInputMappings( results, &num_unused_input_mappings, &src_names, &src_indexes); ASSERT_EQ(1, num_unused_input_mappings); EXPECT_EQ(string("fake"), string(src_names[0])); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index c291a2e440a8515e968b0ce0395b289080f04e8b..37439ff0beac5a5220460465e954b6c093ee1ba9 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -193,6 +193,15 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, return TF_FinishOperation(desc, s); } +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s) { + TF_OperationDescription* desc = + TF_NewOperation(graph, "RandomUniform", "random_uniform"); + TF_AddInput(desc, {shape, 0}); + TF_SetAttrType(desc, "dtype", dtype); + return TF_FinishOperation(desc, s); +} + void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op) { TF_Operation* zero = ScalarConst( diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index d54733749248fa32c39d88bb0281d329dd50c7bd..3429009a71a863ae6b69b5cd29ace3c7fd078f4c 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -74,7 +74,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); -// Split `input` along the first dimention into 3 tensors +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s); + +// Split `input` along the first dimension into 3 tensors TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name = "split3"); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 8359de62b7ff690fec9f6a0e3280f947c62f8b6e..706c89536db019c7f7389af576815746b2425520 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -571,6 +571,12 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, status->status = ctx->func_lib_def.AddFunctionDef(function_def); } +void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, + TF_Status* status) { + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); +} + } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 865580c5f3a823d9cf49fe460bd007e3b3b88767..ca105962df0d6655946304159937621022e7fcba 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -200,6 +200,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, const char* serialized_function_def, size_t size, TF_Status* status); +// Adds a function (created from TF_GraphToFunction or +// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with +// TFE_Execute by creating an op with the same name as the function. +TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, + TF_Function* function, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 4af91b8853d0e85570bad136752a9d0a04b87da5..3fe0b7efa11bc619ed98bf9a1634ade5b6ed0a7c 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -295,6 +295,67 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } +TEST(CAPI, Function) { + // First create a simple identity function. + TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + TFE_DeleteContext(ctx, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} + string MatMulFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 29d73c5ca43a9ad3dbbc5d0f9c08b0b704724b03..2b65e38f54090af6731685f78d5f7f914a875e3c 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -106,6 +106,12 @@ class VSpace { // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; + + // Lets this VSpace know that it can release resources held by the + // `backward_function`, It will not be called again. + // `backward_function` must not be null. + virtual void ReleaseBackwardFunction( + BackwardFunction* backward_function) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and @@ -113,7 +119,11 @@ class VSpace { template class GradientTape { public: - GradientTape() {} + // If `persistent` is true, GradientTape will not eagerly delete backward + // functions (and hence the tensors they keep alive). Instead, everything + // is deleted in ~GradientTape. Persistent GradientTapes are useful when + // users want to compute multiple gradients over the same tape. + GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { pair.second.backward_function_deleter(); @@ -150,6 +160,10 @@ class GradientTape { // Map from tensor id to number of remaining usages (i.e. how many entries in // the tape refer to it); to aid in tape garbage collection. std::unordered_map tensor_usage_; + + // If false, all activations are deleted in the first call to ComputeGradient. + // Else, only when this is destructed. + bool persistent_; }; // Template instantiations here @@ -279,11 +293,16 @@ struct BackpropInitialState { std::unordered_map op_missing_tensor; }; +// If `persistent_tape` is true, op_tape is not changed and none of the +// backwards functions are deleted. +// If `persistent_tape` is false, op_tape is cleared and backwards functions +// not needed for gradient computation are deleted. Backwards functions that +// are needed, are copied and returned in BackpropInitialState. template BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, - OpTape op_tape, - const std::unordered_set& sources_set) { + OpTape* op_tape, + const std::unordered_set& sources_set, bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { @@ -298,9 +317,9 @@ BackpropInitialState PrepareBackprop( continue; } int64 op_id = op_id_it->second; - auto op_it = op_tape.find(op_id); + auto op_it = op_tape->find(op_id); auto result_op_it = result.op_tape.find(op_id); - if (op_id == -1 || op_it == op_tape.end() || + if (op_id == -1 || op_it == op_tape->end() || result_op_it != result.op_tape.end()) { continue; } @@ -317,7 +336,9 @@ BackpropInitialState PrepareBackprop( } } } - op_tape.erase(op_it); + if (!persistent_tape) { + op_tape->erase(op_it); + } } for (auto& pair : result.tensor_usage_counts) { auto it = tensor_tape.find(pair.first); @@ -325,9 +346,15 @@ BackpropInitialState PrepareBackprop( result.op_missing_tensor[it->second] += 1; } } - // Call destructors for all unneeded gradient functions. - for (const auto& op_pair : op_tape) { - op_pair.second.backward_function_deleter(); + if (!persistent_tape) { + // Call destructors for all unneeded gradient functions and + // clear the op_tape. We can clear the tape because ownership of + // backward functions that will be used for gradient computation + // has been transferred to `result`. + for (const auto& op_pair : *op_tape) { + op_pair.second.backward_function_deleter(); + } + op_tape->clear(); } return result; } @@ -369,7 +396,8 @@ Status InitialGradients( auto op_it = op_tape.find(tensor_it->second); if (op_it == op_tape.end()) { return errors::Internal( - "Internal state of the gradient tape is invalid."); + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); } bool found = false; for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { @@ -383,7 +411,8 @@ Status InitialGradients( } if (!found) { return errors::Internal( - "Internal state of the gradient tape is invalid."); + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); } } else { // No record of the target tensor found on the tape, so no gradient @@ -415,17 +444,19 @@ Status GradientTape::ComputeGradient( std::unordered_set sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); BackpropInitialState state = PrepareBackprop( - target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set); + target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); std::unordered_map> gradients; Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, tensor_tape_, state.op_tape, state.tensor_usage_counts, &gradients); - auto cleanup = [&state]() { - // Release all backprop functions - for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); + auto cleanup = [this, &state]() { + if (!persistent_) { + // Release all backprop functions + for (const auto& pair : state.op_tape) { + pair.second.backward_function_deleter(); + } } }; if (!s.ok()) { @@ -460,6 +491,7 @@ Status GradientTape::ComputeGradient( state.op_tape.erase(op_it); std::vector out_gradients; out_gradients.reserve(trace.output_tensor_info.size()); + bool any_gradient_nonzero = false; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { const int64 id = trace.output_tensor_info[i].id; auto grad_it = gradients.find(id); @@ -475,6 +507,7 @@ Status GradientTape::ComputeGradient( trace.output_tensor_info[i].dtype)); } } else { + any_gradient_nonzero = true; out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); @@ -482,12 +515,26 @@ Status GradientTape::ComputeGradient( } } std::vector in_gradients; - Status s = vspace.CallBackwardFunction(trace.backward_function, - out_gradients, &in_gradients); - if (!s.ok()) { - VLOG(1) << "Gradient function failed."; - cleanup(); - return s; + if (any_gradient_nonzero) { + Status s = vspace.CallBackwardFunction(trace.backward_function, + out_gradients, &in_gradients); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + if (!s.ok()) { + cleanup(); + return s; + } + } else { + in_gradients.resize(trace.input_tensor_id.size()); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + for (Gradient* grad : out_gradients) { + if (grad != nullptr) { + vspace.DeleteGradient(grad); + } + } } VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " << trace.input_tensor_id.size() << " sources"; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index ba5a9268b4f671499590d66fb41060dd18e1ce47..6e37cdb5f4beea53d4a2ded0705ae482d0bc2d68 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -22,6 +22,7 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { mutex_lock l(graph->mu); graph->graph.AddControlEdge(&input->node, &op->node); + RecordMutation(graph, *op, "adding control input"); } void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, @@ -36,11 +37,13 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, mutex_lock l(graph->mu); op->node.AddAttr(attr_name, attr_val); + RecordMutation(graph, *op, "setting attribute"); } void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); + RecordMutation(graph, *op, "setting device"); } void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, @@ -75,6 +78,25 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); + + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst.oper, "updating input tensor"); + } +} + +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { + mutex_lock l(graph->mu); + std::vector control_edges; + for (const Edge* edge : op->node.in_edges()) { + if (!edge->IsControlEdge()) continue; + control_edges.push_back(edge); + } + for (const Edge* edge : control_edges) { + graph->graph.RemoveControlEdge(edge); + } } } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index f54585b0a1034ff108202272a11416e34985959e..b51ef2b53122802fef598a26bd6f1843976f11b0 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -35,6 +35,8 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 80112f9b44b1d5fd65a7d47788b072dc47a2b29a..e354831d7d25af83c068a68a4f844056263a598c 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -421,6 +421,7 @@ tf_cc_test( tf_gen_op_wrappers_cc( name = "cc_ops", + api_def_srcs = ["//tensorflow/core:base_api_def"], op_lib_names = [ "array_ops", "audio_ops", @@ -525,6 +526,30 @@ cc_library_with_android_deps( "//tensorflow/core:android_tensorflow_lib", ], copts = tf_copts(), + data = [ + "//tensorflow/core:base_api_def", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", + "//tensorflow/core:op_gen_overrides_proto_cc", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "cc_op_gen_test", + srcs = [ + "framework/cc_op_gen.cc", + "framework/cc_op_gen.h", + "framework/cc_op_gen_test.cc", + ], + data = [ + "//tensorflow/cc:ops/op_gen_overrides.pbtxt", + ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -533,6 +558,8 @@ cc_library_with_android_deps( "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 38a17598b8e4161f96ab8134823de033d3284440..d889c518f9c38a9f070970b37a2ad4b1fc26671b 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include "tensorflow/cc/framework/cc_op_gen.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/framework/op_gen_overrides.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb_text.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { - namespace { const int kRightMargin = 79; @@ -297,7 +297,7 @@ string ToCamelCase(const string& str) { // argument to a function. std::pair AttrTypeName(StringPiece attr_type) { static const std::unordered_map, - StringPiece::Hasher> + StringPieceHasher> attr_type_map{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, @@ -325,29 +325,112 @@ std::pair AttrTypeName(StringPiece attr_type) { } bool IsCPPKeyword(StringPiece name) { - static const std::unordered_set + static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword kCPPReserved{ - "alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", - "atomic_commit", "atomic_noexcept", "auto", "bitand", "bitor", "bool", - "break", "case", "catch", "char", "char16_t", "char32_t", "class", - "compl", "concept", "const", "const_cast", "constexpr", "continue", - "decltype", "default", "delete", "do", "double", "dynamic_cast", - "else", "enum", "explicit", "export", "extern", "false", "final", - "float", "for", "friend", "goto", "if", "import", "inline", "int", - "long", "module", "mutable", "namespace", "new", "noexcept", "not", - "not_eq", "nullptr", "operator", "or", "or_eq", "override", "private", - "protected", "public", "register", "reinterpret_cast", "requires", - "return", "short", "signed", "sizeof", "static", "static_assert", - "static_cast", "struct", "switch", "synchronized", "template", "this", - "thread_local", "throw", "true", "try", "typedef", "typeid", - "typename", "union", "unsigned", "using", "virtual", "void", - "volatile", "wchar_t", "while", "xor", "xor_eq", + "alignas", + "alignof", + "and", + "and_eq", + "asm", + "atomic_cancel", + "atomic_commit", + "atomic_noexcept", + "auto", + "bitand", + "bitor", + "bool", + "break", + "case", + "catch", + "char", + "char16_t", + "char32_t", + "class", + "compl", + "concept", + "const", + "const_cast", + "constexpr", + "continue", + "decltype", + "default", + "delete", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "export", + "extern", + "false", + "final", + "float", + "for", + "friend", + "goto", + "if", + "import", + "inline", + "int", + "long", + "module", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "override", + "private", + "protected", + "public", + "register", + "reinterpret_cast", + "requires", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "synchronized", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "union", + "unsigned", + "using", + "virtual", + "void", + "volatile", + "wchar_t", + "while", + "xor", + "xor_eq", // The following are not C++ keywords, but names of local variables // and parameters used in the op constructor. Treating them as // keywords, so that other parameter names don't conflict with these. - "builder", "node", "ret", "scope", "unique_name", + "builder", + "node", + "ret", + "scope", + "unique_name", }; return kCPPReserved.count(name) > 0; } @@ -385,10 +468,10 @@ bool ArgIsList(const OpDef::ArgDef& arg) { } bool HasOptionalAttrs( - const OpDef& op_def, + const ApiDef& api_def, const std::unordered_map& inferred_input_attrs) { - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < api_def.attr_size(); ++i) { + const auto& attr(api_def.attr(i)); if ((inferred_input_attrs.find(attr.name()) == inferred_input_attrs.end()) && attr.has_default_value()) { @@ -398,12 +481,21 @@ bool HasOptionalAttrs( return false; } +const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { + for (int i = 0; i < api_def.in_arg_size(); ++i) { + if (api_def.in_arg(i).name() == name) { + return &api_def.in_arg(i); + } + } + return nullptr; +} + struct OpInfo { // graph_op_def: The OpDef used by the runtime, has the names that // must be used when calling NodeBuilder. // interface_op_def: The OpDef used in the interface in the generated // code, with possibly overridden names and defaults. - explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def, + explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases); string GetOpAttrStruct() const; string GetConstructorDecl(StringPiece op_name_prefix, @@ -423,74 +515,81 @@ struct OpInfo { string comment; const OpDef& graph_op_def; - const OpDef& op_def; + const ApiDef& api_def; const std::vector& aliases; + // Map from type attribute to corresponding original argument name. std::unordered_map inferred_input_attrs; }; -OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, - const std::vector& a) - : graph_op_def(g_op_def), op_def(i_op_def), aliases(a) { - op_name = op_def.name(); - InferOpAttributes(op_def, &inferred_input_attrs); - has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs); +OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, + const std::vector& aliases) + : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) { + op_name = api_def.endpoint(0).name(); + InferOpAttributes(graph_op_def, &inferred_input_attrs); + has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs); arg_types.push_back("const ::tensorflow::Scope&"); arg_names.push_back("scope"); - if (op_def.has_deprecation()) { - if (!op_def.summary().empty()) { - comment = strings::StrCat(op_def.summary(), "\n"); + if (graph_op_def.has_deprecation()) { + if (!api_def.summary().empty()) { + comment = strings::StrCat(api_def.summary(), "\n"); } strings::StrAppend(&comment, "DEPRECATED at GraphDef version ", - op_def.deprecation().version(), ":\n", - op_def.deprecation().explanation(), ".\n"); - } else if (op_def.summary().empty()) { + graph_op_def.deprecation().version(), ":\n", + graph_op_def.deprecation().explanation(), ".\n"); + } else if (api_def.summary().empty()) { comment = "TODO: add doc.\n"; } else { - comment = strings::StrCat(op_def.summary(), "\n"); + comment = strings::StrCat(api_def.summary(), "\n"); } - if (!op_def.description().empty()) { - strings::StrAppend(&comment, "\n", op_def.description(), "\n"); + if (!api_def.description().empty()) { + strings::StrAppend(&comment, "\n", api_def.description(), "\n"); } strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n"); // Process inputs - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); + for (int i = 0; i < api_def.arg_order_size(); ++i) { + const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def); + const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def); arg_types.push_back(strings::StrCat( "::tensorflow::", ArgIsList(arg) ? "InputList" : "Input")); - arg_names.push_back(AvoidCPPKeywords(arg.name())); + arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); // TODO(keveman): Include input type information. - StringPiece description = arg.description(); + StringPiece description = api_def_arg.description(); if (!description.empty()) { ConsumeEquals(&description); - strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ", - arg.description(), "\n"); + strings::StrAppend(&comment, "* ", + AvoidCPPKeywords(api_def_arg.rename_to()), ": ", + api_def_arg.description(), "\n"); } } // Process attrs string required_attrs_comment; string optional_attrs_comment; - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + // ApiDef attributes must be in the same order as in OpDef since + // we initialize ApiDef based on OpDef. + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); + CHECK_EQ(attr.name(), api_def_attr.name()); // Skip inferred arguments if (inferred_input_attrs.count(attr.name()) > 0) continue; const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - string attr_name = AvoidCPPKeywords(attr.name()); + string attr_name = AvoidCPPKeywords(api_def_attr.rename_to()); string attr_comment; - if (!attr.description().empty()) { + if (!api_def_attr.description().empty()) { // TODO(keveman): Word wrap and indent this, to handle multi-line // descriptions. strings::StrAppend(&attr_comment, "* ", attr_name, ": ", - attr.description(), "\n"); + api_def_attr.description(), "\n"); } - if (attr.has_default_value()) { + if (api_def_attr.has_default_value()) { strings::StrAppend(&optional_attrs_comment, attr_comment); } else { strings::StrAppend(&required_attrs_comment, attr_comment); @@ -508,44 +607,49 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, } // Process outputs - for (int i = 0; i < op_def.output_arg_size(); ++i) { - const auto& arg = op_def.output_arg(i); + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { + // ApiDef arguments must be in the same order as in OpDef since + // we initialize ApiDef based on OpDef. + const auto& arg = graph_op_def.output_arg(i); + const auto& api_def_arg(api_def.out_arg(i)); + CHECK_EQ(arg.name(), api_def_arg.name()); + bool is_list = ArgIsList(arg); output_types.push_back( strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output")); - output_names.push_back(AvoidCPPKeywords(arg.name())); + output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); is_list_output.push_back(is_list); } strings::StrAppend(&comment, "\nReturns:\n"); - if (op_def.output_arg_size() == 0) { // No outputs. + if (graph_op_def.output_arg_size() == 0) { // No outputs. strings::StrAppend(&comment, "* the created `Operation`\n"); - } else if (op_def.output_arg_size() == 1) { // One output + } else if (graph_op_def.output_arg_size() == 1) { // One output if (is_list_output[0]) { strings::StrAppend(&comment, "* `OutputList`: "); } else { strings::StrAppend(&comment, "* `Output`: "); } - if (op_def.output_arg(0).description().empty()) { - strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(), + if (api_def.out_arg(0).description().empty()) { + strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(), " tensor.\n"); } else { // TODO(josh11b): Word wrap this. - strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n"); + strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n"); } } else { // Multiple outputs. - for (int i = 0; i < op_def.output_arg_size(); ++i) { + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { if (is_list_output[i]) { strings::StrAppend(&comment, "* `OutputList`"); } else { strings::StrAppend(&comment, "* `Output`"); } strings::StrAppend(&comment, " ", output_names[i]); - if (op_def.output_arg(i).description().empty()) { + if (api_def.out_arg(i).description().empty()) { strings::StrAppend(&comment, "\n"); } else { // TODO(josh11b): Word wrap this. - strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(), + strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(), "\n"); } } @@ -564,19 +668,20 @@ string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); // If attr will be inferred or it doesn't have a default value, don't // add it to the struct. if ((inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) || - !attr.has_default_value()) { + !api_def_attr.has_default_value()) { continue; } const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - const string camel_case_name = ToCamelCase(attr.name()); + const string camel_case_name = ToCamelCase(api_def_attr.rename_to()); const string suffix = (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : ""; const string attr_func_def = @@ -584,22 +689,25 @@ string OpInfo::GetOpAttrStruct() const { attr_type_name, use_const ? "&" : ""); string attr_comment; - if (!attr.description().empty()) { - strings::StrAppend(&attr_comment, attr.description(), "\n\n"); + if (!api_def_attr.description().empty()) { + strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n"); } strings::StrAppend(&attr_comment, "Defaults to ", - SummarizeAttrValue(attr.default_value()), "\n"); + SummarizeAttrValue(api_def_attr.default_value()), "\n"); attr_comment = MakeComment(attr_comment, " "); strings::StrAppend(&setters, attr_comment); strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n"); strings::StrAppend(&setters, " Attrs ret = *this;\n"); - strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n"); + strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(), + "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ", - PrintAttrValue(op_def.name(), attr.default_value()), ";\n"); + &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), + "_ = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); } if (struct_fields.empty()) { @@ -676,17 +784,18 @@ void OpInfo::WriteClassDecl(WritableFile* h) const { // Add the static functions to set optional attrs if (has_optional_attrs) { strings::StrAppend(&class_decl, "\n"); - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); if ((inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) || - !attr.has_default_value()) { + !api_def_attr.has_default_value()) { continue; } const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - const string camel_case_name = ToCamelCase(attr.name()); + const string camel_case_name = ToCamelCase(api_def_attr.rename_to()); const string suffix = (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : ""; const string attr_func_def = strings::StrCat( @@ -726,11 +835,11 @@ void OpInfo::GetOutput(string* out) const { strings::StrCat("if (!", scope_str, ".ok()) return;"); // No outputs. - if (op_def.output_arg_size() == 0) { + if (graph_op_def.output_arg_size() == 0) { strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n"); return; } - if (op_def.output_arg_size() == 1) { + if (graph_op_def.output_arg_size() == 1) { // One output, no need for NameRangeMap if (is_list_output[0]) { strings::StrAppend(out, @@ -752,7 +861,7 @@ void OpInfo::GetOutput(string* out) const { ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); - for (int i = 0; i < op_def.output_arg_size(); ++i) { + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { const string arg_range = strings::StrCat( "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]"); if (is_list_output[i]) { @@ -776,11 +885,13 @@ string OpInfo::GetConstructorBody() const { strings::StrAppend(&body, " ", return_on_error, "\n"); - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::", - ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", - scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n"); + for (int i = 0; i < graph_op_def.input_arg_size(); ++i) { + const auto& arg(graph_op_def.input_arg(i)); + const auto& api_def_arg(api_def.in_arg(i)); + strings::StrAppend( + &body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::", + ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ", + AvoidCPPKeywords(api_def_arg.rename_to()), ");\n"); strings::StrAppend(&body, " ", return_on_error, "\n"); } @@ -791,19 +902,21 @@ string OpInfo::GetConstructorBody() const { &body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"", graph_op_def.name(), "\")\n"); const string spaces = " "; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n"); + for (int i = 0; i < api_def.in_arg_size(); ++i) { + const auto& arg(api_def.in_arg(i)); + strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n"); } - for (int i = 0; i < op_def.attr_size(); ++i) { + for (int i = 0; i < api_def.attr_size(); ++i) { const auto& graph_attr(graph_op_def.attr(i)); - const auto& attr(op_def.attr(i)); - if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) { + const auto& api_def_attr(api_def.attr(i)); + if (inferred_input_attrs.find(api_def_attr.name()) != + inferred_input_attrs.end()) { continue; } - const string attr_name = attr.has_default_value() - ? strings::StrCat("attrs.", attr.name(), "_") - : AvoidCPPKeywords(attr.name()); + const string attr_name = + api_def_attr.has_default_value() + ? strings::StrCat("attrs.", api_def_attr.rename_to(), "_") + : AvoidCPPKeywords(api_def_attr.rename_to()); strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ", attr_name, ")\n"); } @@ -845,10 +958,10 @@ void OpInfo::WriteClassDef(WritableFile* cc) const { TF_CHECK_OK(cc->Append(class_def)); } -void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def, +void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases, WritableFile* h, WritableFile* cc) { - OpInfo op_info(graph_op_def, interface_op_def, aliases); + OpInfo op_info(graph_op_def, api_def, aliases); op_info.WriteClassDecl(h); op_info.WriteClassDef(cc); @@ -943,8 +1056,9 @@ string MakeInternal(const string& fname) { } // namespace -void WriteCCOps(const OpList& ops, const string& dot_h_fname, - const string& dot_cc_fname, const string& overrides_fnames) { +void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& dot_h_fname, const string& dot_cc_fname, + const string& overrides_fnames) { Env* env = Env::Default(); // Load the override map. @@ -984,24 +1098,23 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname, // code depends on it. if (graph_op_def.name() == "Const") continue; - // Incorporate overrides from override_map. - OpDef interface_op_def = graph_op_def; - const OpGenOverride* op_override = - override_map.ApplyOverride(&interface_op_def); + const auto* api_def = api_def_map.GetApiDef(graph_op_def.name()); + std::vector aliases; - if (op_override) { - if (op_override->skip()) continue; - aliases.assign(op_override->alias().begin(), op_override->alias().end()); - if (op_override->hide()) { - // Write hidden ops to _internal.h and _internal.cc. - WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(), - internal_cc.get()); - continue; - } + if (api_def->visibility() == ApiDef::SKIP) continue; + // First endpoint is canonical, the rest are aliases. + for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size(); + ++endpoint_i) { + aliases.push_back(api_def->endpoint(endpoint_i).name()); + } + if (api_def->visibility() == ApiDef::HIDDEN) { + // Write hidden ops to _internal.h and _internal.cc. + WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(), + internal_cc.get()); + continue; } - // This isn't a hidden op, write it to the main files. - WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get()); + WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get()); } FinishFiles(false, h.get(), cc.get(), op_header_guard); diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index fa5e004f0317d046d82bee005bdf9f17773a45f3..cea28990144b9371e8009ce13f912b44044f9aac 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -17,13 +17,15 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { /// Result is written to files dot_h and dot_cc. -void WriteCCOps(const OpList& ops, const string& dot_h_fname, - const string& dot_cc_fname, const string& overrides_fnames); +void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& dot_h_fname, const string& dot_cc_fname, + const string& overrides_fnames); } // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc index 3b80cf993eb9a5d5f4c41687577414e7216dd174..326d5668b8803ee39ffe24900c92e1db87b93601 100644 --- a/tensorflow/cc/framework/cc_op_gen_main.cc +++ b/tensorflow/cc/framework/cc_op_gen_main.cc @@ -16,7 +16,11 @@ limitations under the License. #include "tensorflow/cc/framework/cc_op_gen.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/types.h" @@ -24,10 +28,28 @@ namespace tensorflow { namespace { void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, - const std::string& overrides_fnames, bool include_internal) { + const std::string& overrides_fnames, bool include_internal, + const std::vector& api_def_dirs) { OpList ops; OpRegistry::Global()->Export(include_internal, &ops); - WriteCCOps(ops, dot_h, dot_cc, overrides_fnames); + ApiDefMap api_def_map(ops); + if (!api_def_dirs.empty()) { + Env* env = Env::Default(); + // Only load files that correspond to "ops". + for (const auto& op : ops.op()) { + for (const auto& api_def_dir : api_def_dirs) { + const std::string api_def_file_pattern = + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); + if (env->FileExists(api_def_file_pattern).ok()) { + TF_CHECK_OK(api_def_map.LoadFile(env, api_def_file_pattern)); + } + } + } + } + + api_def_map.UpdateDocs(); + + WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames); } } // namespace @@ -35,18 +57,24 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, int main(int argc, char* argv[]) { tensorflow::port::InitMain(argv[0], &argc, &argv); - if (argc != 5) { + // TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt + // as an argument. + if (argc != 6) { for (int i = 1; i < argc; ++i) { fprintf(stderr, "Arg %d = %s\n", i, argv[i]); } fprintf(stderr, - "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n" + "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal " + "api_def_dirs1,api_def_dir2 ...\n" " include_internal: 1 means include internal ops\n", argv[0]); exit(1); } bool include_internal = tensorflow::StringPiece("1") == argv[4]; - tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal); + std::vector api_def_dirs = tensorflow::str_util::Split( + argv[5], ",", tensorflow::str_util::SkipEmpty()); + tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal, + api_def_dirs); return 0; } diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b7e720a5c7b343415eee1aa157b8de755a1e1a5 --- /dev/null +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/framework/cc_op_gen.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// TODO(annarev): Remove this op_gen_overrides.pbtxt reference. +// It is needed only because WriteCCOps takes it as an argument. +constexpr char kOverridesFnames[] = + "tensorflow/cc/ops/op_gen_overrides.pbtxt"; +constexpr char kBaseOpDef[] = R"( +op { + name: "Foo" + input_arg { + name: "images" + description: "Images to process." + } + input_arg { + name: "dim" + description: "Description for dim." + type: DT_FLOAT + } + output_arg { + name: "output" + description: "Description for output." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for images" + allowed_values { + list { + type: DT_UINT8 + type: DT_INT8 + } + } + default_value { + i: 1 + } + } + summary: "Summary for op Foo." + description: "Description for op Foo." +} +)"; + +void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(s.contains(expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + +void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { + EXPECT_FALSE(s.contains(expected)) + << "'" << s << "' contains '" << expected << "'"; +} + +void ExpectSubstrOrder(const string& s, const string& before, + const string& after) { + int before_pos = s.find(before); + int after_pos = s.find(after); + ASSERT_NE(std::string::npos, before_pos); + ASSERT_NE(std::string::npos, after_pos); + EXPECT_LT(before_pos, after_pos) + << before << " is not before " << after << " in " << s; +} + +// Runs WriteCCOps and stores output in (internal_)cc_file_path and +// (internal_)h_file_path. +void GenerateCcOpFiles(Env* env, const OpList& ops, + const ApiDefMap& api_def_map, string* h_file_text, + string* internal_h_file_text) { + const string& tmpdir = testing::TmpDir(); + + const auto h_file_path = io::JoinPath(tmpdir, "test.h"); + const auto cc_file_path = io::JoinPath(tmpdir, "test.cc"); + const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h"); + const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc"); + + WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames); + + TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text)); + TF_ASSERT_OK( + ReadFileToString(env, internal_h_file_path, internal_h_file_text)); +} + +TEST(CcOpGenTest, TestVisibilityChangedToHidden) { + const string api_def = R"( +op { + graph_op_name: "Foo" + visibility: HIDDEN +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + ApiDefMap api_def_map(op_defs); + + string h_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo"); + ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(internal_h_file_text, "class Foo"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo"); +} + +TEST(CcOpGenTest, TestArgNameChanges) { + const string api_def = R"( +op { + graph_op_name: "Foo" + arg_order: "dim" + arg_order: "images" +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + + ApiDefMap api_def_map(op_defs); + string cc_file_text, h_file_text; + string internal_cc_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectSubstrOrder(h_file_text, "Input images", "Input dim"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectSubstrOrder(h_file_text, "Input dim", "Input images"); +} + +TEST(CcOpGenTest, TestEndpoints) { + const string api_def = R"( +op { + graph_op_name: "Foo" + endpoint { + name: "Foo1" + } + endpoint { + name: "Foo2" + } +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + + ApiDefMap api_def_map(op_defs); + string cc_file_text, h_file_text; + string internal_cc_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo {"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo1"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo2"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo1"); + ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo {"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index d7446b9560fd7dc8377ea3710641906b274313a9..ebc0c77828dc32ec170d4ddfbfa150d1f38ab27b 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -728,6 +728,24 @@ Status LgammaGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Lgamma", LgammaGrad); +Status SelectGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto comparator = op.input(0); + auto x = op.input(1); + auto zeros = ZerosLike(scope, x); + auto grad = grad_inputs[0]; + + auto gx_1 = Where3(scope, comparator, grad, zeros); + auto gx_2 = Where3(scope, comparator, zeros, grad); + + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(gx_1); + grad_outputs->push_back(gx_2); + return scope.status(); +} +REGISTER_GRADIENT_OP("Select", SelectGrad); + Status MinOrMaxGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 6313f41da5e5f9cf88be4c8a84408a8df77f0e25..29def3c3ea2b0be963cae000db587f94fae5af55 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -865,5 +865,13 @@ TEST_F(NaryGradTest, Minimum) { RunTest(x, x_init_value, y, shape); } +TEST_F(NaryGradTest, Select) { + TensorShape shape({3, 4}); + auto x1 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto x2 = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Where3(scope_, Greater(scope_, x1, x2), x1, x2); + RunTest({x1, x2}, {shape, shape}, {y}, {shape}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 09fadfcab51575798286876f9a4e0ee9a60940ac..13a3bba5e6d5ca19ff3f0eca76665ba7d3ab628d 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -196,6 +196,18 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status LRNGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs){ + internal::LRNGrad::Attrs grad_attrs; + + auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0), + grad_attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LRN", LRNGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index ac66f51cf01911957722e94ca28e8e78dc6de2ed..f9063e836509669d81d03b1d2f0d32d1166b6eca 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -191,5 +191,12 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, LRN){ + TensorShape x_shape({1, 1, 2, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = LRN(scope_, x); + RunTest(x, x_shape, y, x_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index f98abc8a817eca7bc129bb03a2ad31b97d957065..acef098c7d07f45d171679bff7c41e13ef0424f1 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -62,6 +62,15 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { export_dir); } +string GetTagsAsString(const std::unordered_set& tags) { + string tags_as_string = "{ "; + for (const string& tag : tags) { + tags_as_string = strings::StrCat(tags_as_string, tag, " "); + } + tags_as_string = strings::StrCat(tags_as_string, "}"); + return tags_as_string; +} + Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, const std::unordered_set& tags, MetaGraphDef* meta_graph_def_to_load) { @@ -77,14 +86,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, return Status::OK(); } } - string tags_as_string = "{ "; - for (const string& tag : tags) { - tags_as_string = strings::StrCat(tags_as_string, tag, " "); - } - tags_as_string = strings::StrCat(tags_as_string, "}"); return Status(error::Code::NOT_FOUND, "Could not find meta graph def matching supplied tags: " + - tags_as_string + + GetTagsAsString(tags) + ". To inspect available tag-sets in the SavedModel, please " "use the SavedModel CLI: `saved_model_cli`"); } @@ -233,7 +237,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, return Status(error::Code::NOT_FOUND, "SavedModel not found in export directory: " + export_dir); } - LOG(INFO) << "Loading SavedModel from: " << export_dir; + LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags) + << "; from: " << export_dir; SavedModel saved_model_proto; TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); @@ -281,7 +286,8 @@ Status LoadSavedModel(const SessionOptions& session_options, return end_microseconds - start_microseconds; }(); auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "Loading SavedModel: " << status_str << ". Took " + LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags) + << "; Status: " << status_str << ". Took " << load_latency_microsecs << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h index 2b0b2d5c7fb33768494c1781669c1adcb875a579..b71cb263ca42dab7e830c1880ec4b311bc272f82 100644 --- a/tensorflow/cc/saved_model/tag_constants.h +++ b/tensorflow/cc/saved_model/tag_constants.h @@ -21,6 +21,9 @@ namespace tensorflow { /// Tag for the `gpu` graph. constexpr char kSavedModelTagGpu[] = "gpu"; +/// Tag for the `tpu` graph. +constexpr char kSavedModelTagTpu[] = "tpu"; + /// Tag for the `serving` graph. constexpr char kSavedModelTagServe[] = "serve"; diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index a9a6ea84319a18a8fbce648391bf5918ff6d9a08..5740c040e309bad8d7e3bdc468c09a3323fb99e0 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -24,7 +24,6 @@ tf_cc_test( srcs = ["runtime_test.cc"], deps = [ ":runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -111,6 +110,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ae22f7edc423247b34895411d19d7a3c21f86d4f..53da2881b60db9ad39565567623eb86f754559af 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -101,21 +101,8 @@ Status ComputeArgSizes(const CompileResult& compile_result, std::vector* arg_sizes) { const xla::ProgramShape& ps = compile_result.program_shape; for (int i = 0; i < ps.parameters_size(); ++i) { - if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. - const xla::PrimitiveType type = ps.parameters(i).element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(ps)); - } - arg_sizes->push_back(-1); - } else { - arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( - ps.parameters(i), compile_result.pointer_size)); - } + arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( + ps.parameters(i), compile_result.pointer_size)); } return Status::OK(); } @@ -165,11 +152,6 @@ string RewriteWithName(const string& name, string code, Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and is set in the class constructor. - num_args--; - } if (config.feed_size() != num_args) { return errors::InvalidArgument("mismatch between feed_size(", config.feed_size(), ") and num_args(", @@ -418,7 +400,7 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); {{NS_START}} // {{CLASS}} represents a computation previously specified in a @@ -474,7 +456,6 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -483,7 +464,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} {{CLASS}}(const {{CLASS}}&) = delete; @@ -496,8 +477,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -560,8 +541,6 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{HAS_CONTEXT_ARG}}", - compile_result.has_context_arg ? "true" : "false"}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 0f6114666fcc89c631434527d2ae8c92c039ffea..75026c57c04a64186a1e5be6c41e4dd7de8520b7 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -145,11 +145,9 @@ TEST(GenerateHeader, Golden) { { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - xla::ShapeUtil::MakeOpaqueShape(), }, xla::ShapeUtil::MakeTupleShape( {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); - compile_result.has_context_arg = true; compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; string header; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 65f342ce27ef09092f252f791973f245a8cdd6f3..95ab3a7332f51070732bf5d62c7926c84e3b738d 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -19,7 +19,7 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); namespace foo { namespace bar { @@ -48,7 +48,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6]) +// ((unknown): f32[1,2], (unknown): s64[3,4]) -> (u32[5,6]) // // Memory stats: // arg bytes total: 104 @@ -58,11 +58,11 @@ namespace bar { class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. - static constexpr size_t kNumArgs = 3; + static constexpr size_t kNumArgs = 2; // Byte size of each argument buffer. There are kNumArgs entries. static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1}; + static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; return kArgSizes; } @@ -77,7 +77,6 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = true; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -86,7 +85,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} MyClass(const MyClass&) = delete; @@ -99,8 +98,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -236,8 +235,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // Shape of the args and results. static const xla::ProgramShape* StaticProgramShape() { static const xla::ProgramShape* kShape = []() { - static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; - static constexpr int kProtoSize = 50; + static const char kProto[] = {10,14,16,11,26,2,1,2,42,6,10,2,1,0,32,1,10,14,16,5,26,2,3,4,42,6,10,2,1,0,32,1,18,18,16,13,34,14,16,8,26,2,5,6,42,6,10,2,1,0,32,1}; + static constexpr int kProtoSize = 52; xla::ProgramShape* shape = new xla::ProgramShape; shape->ParseFromArray(kProto, kProtoSize); return shape; diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b8cc6024cb85e4f6269313927ff66d1d9a1cf79..c87f2b75dfa18ad5c3eda4bd6fcbcb3083ef73fd 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -94,9 +94,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); xla::Computation computation; - TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, - &computation, - &compile_result->has_context_arg)); + TF_RETURN_IF_ERROR( + ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 965c2960816b3acc8d2209e6824d88647de0ce14..e03c5b1aa77c1262ed903aae3072ef65f34d80a2 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -34,7 +34,6 @@ struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; xla::ProgramShape program_shape; // Static shape of args and results. - bool has_context_arg = false; // Is last arg XlaLocalRuntimeContext? string entry_point; // Name of generated function. int pointer_size = 0; // Size of a pointer in bytes. }; diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index ac79c278c1fdf8b6aedcb52121c767b8ba0ad358..6d603a02eb4ceade6832ba67b2981814ee25327a 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index a898eab1d1ab0eb5d55983bf366753c968887296..89c7cd4507cbd476104a039d6083d8f89de11278 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import argparse +import os import sys from tensorflow.core.protobuf import saver_pb2 @@ -53,7 +54,7 @@ def tfadd_with_ckpt(out_dir): sess.run(init_op) sess.run(y.assign(y + 42)) # Without the checkpoint, the variable won't be set to 42. - ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % out_dir + ckpt = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt.ckpt') saver.save(sess, ckpt) @@ -68,10 +69,10 @@ def tfadd_with_ckpt_saver(out_dir): sess.run(init_op) sess.run(y.assign(y + 42)) # Without the checkpoint, the variable won't be set to 42. - ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % out_dir + ckpt_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.ckpt') saver.save(sess, ckpt_file) # Without the SaverDef, the restore op won't be named correctly. - saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir + saver_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.saver') with open(saver_file, 'wb') as f: f.write(saver.as_saver_def().SerializeToString()) @@ -129,7 +130,7 @@ def write_graph(build_graph, out_dir): g = ops.Graph() with g.as_default(): build_graph(out_dir) - filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__) + filename = os.path.join(out_dir, 'test_graph_%s.pb' % build_graph.__name__) with open(filename, 'wb') as f: f.write(g.as_graph_def().SerializeToString()) diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 6b037f276ad1d6771b904bb970f45f32ae9531b8..413efd9cea3b6f71574615ad9ca92471ff925781 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -70,7 +70,7 @@ TEST(TFCompileTest, Add) { // Run tests that use set_argN_data separately, to avoid accidentally re-using // non-existent buffers. TEST(TFCompileTest, Add_SetArg) { - AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); int32 arg_x = 10; int32 arg_y = 32; @@ -258,7 +258,7 @@ TEST(TFCompileTest, MatMul2_SetArg) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); foo::bar::MatMulComp matmul( - foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); matmul.set_thread_pool(&device); // Test using the set_argN_data() methods. diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 1e22b760b8a4189165a59ac307374277474bbc31..542451ed2d14fbceca00c6ccb6e28c1c3a0d4321 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -152,7 +152,7 @@ def tf_library(name, graph, config, " --target_triple=" + target_llvm_triple() + " --out_header=$(@D)/" + header_file + " --out_object=$(@D)/" + object_file + - flags), + " " + flags), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -189,7 +189,7 @@ def tf_library(name, graph, config, " --cpp_class=" + cpp_class + " --target_triple=" + target_llvm_triple() + " --out_session_module=$(@D)/" + session_module_pb + - flags), + " " + flags), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -267,7 +267,6 @@ def tf_library(name, graph, config, srcs=[test_file], deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", "@org_tensorflow//tensorflow/compiler/aot:runtime", "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main", "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", @@ -313,7 +312,6 @@ def tf_library(name, graph, config, linkopts = if_android(["-pie", "-s"]), deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", "@org_tensorflow//tensorflow/compiler/aot:benchmark", "@org_tensorflow//tensorflow/compiler/aot:runtime", "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bf7d9cf14d10f41aa48ea594a8d63db97b9973e1..026a1bf879d373fd0f5f4444b3ce10d01702f82b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -251,6 +251,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 22899ebeebc929055518893b358f7950d380d6f6..407b7dcbfb4b36674928d68eedaf58fcefc645f2 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -16,7 +16,11 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include +#include #include +#include +#include +#include #include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" @@ -32,6 +36,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -48,6 +53,52 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; namespace { +bool AreAllParentsConst(const Node& n, + const gtl::FlatSet& runtime_const_nodes) { + if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") { + // If the current node is itself a cast-to-const, no need + // to look at the incoming edges. + return true; + } + + bool all_parents_const = true; + bool atleast_one_non_control_edge = false; + for (const Edge* in : n.in_edges()) { + atleast_one_non_control_edge = + atleast_one_non_control_edge || !in->IsControlEdge(); + if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) { + all_parents_const = false; + break; + } + } + return all_parents_const && atleast_one_non_control_edge; +} + +void MarkGuaranteedConstants( + const Graph& graph, + const std::vector>& src_arg_pairs) { + gtl::FlatSet guaranteed_const_nodes; + std::vector srcs; + srcs.reserve(src_arg_pairs.size()); + for (const auto& src_arg : src_arg_pairs) { + srcs.push_back(src_arg.first); + } + ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, + /*leave=*/[&guaranteed_const_nodes](const Node* n) { + // TODO(vinuraja): Doesn't work in the presence of loops. + if (AreAllParentsConst(*n, guaranteed_const_nodes)) { + guaranteed_const_nodes.insert(n); + } + }); + + for (auto& src_arg : src_arg_pairs) { + if (guaranteed_const_nodes.count(src_arg.first) != 0) { + VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString(); + src_arg.second->AddAttr("_is_guaranteed_constant", true); + } + } +} + // A node/slot pair. // TODO(phawkins): is there a common definition of this? struct NodeSlot { @@ -75,6 +126,11 @@ struct NodeSlot { }; }; +// TODO(phawkins) add a canonical copy of these operator names and refactor +// everything to use it. +static const char* const kArgOp = "_Arg"; +static const char* const kRetValOp = "_Retval"; + class Encapsulator { public: Encapsulator(string group_attribute, Graph const* graph_in) @@ -99,54 +155,167 @@ class Encapsulator { Status BuildOutputGraph(bool parallel_checking, Graph* graph_out); private: - // Returns the key attribute associated with a node. Returns the empty string - // if no key attribute is found. - string GetFunctionNameAttr(const Node* node) const; - // A subgraph of the input, all marked with a common 'group_attribute' // value. - struct Subgraph { + class Subgraph { + public: + // Creates a graph to build the subgraph in, if it doesn't already exist, + // using the same op registry and versions as graph_in. + Node* MakeNodeImage(const Graph* graph_in, Node* node); + + // Returns the graph the subgraph is being built in. + Graph* GetGraph() const; + + // Builds a FunctionDef, and adds it to 'library'. The value of the + // 'group_attribute' annotations becomes the function name. If + // 'reuse_existing_functions' is set, use an existing function with the same + // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the + // subgraph before function conversion. + Status BuildFunctionDef(const string& name_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, + FunctionLibraryDefinition* library); + + // Adds the function call node to graph_out. + Status AddFunctionCallNode( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out); + + // Returns the Node that inputs to the function should be wired up to. + Node* GetCallNodeForInputs() const; + + // Returns the Node that outputs to the function should be wired up to. + Node* GetCallNodeForOutputs() const; + + // Returns the index of the arg that the dst of edge should connect to. + int GetArgIndexForEdge(const Edge* edge) const; + + // Returns the index of the result that the src of edge should connect to. + int GetResultIndexForEdge(const Edge* edge) const; + + // Creates an _Arg node for the src node of edge, and add its index to + // args_by_src_, if none exists yet. Also adds its index to args_by_dst_, + // and adds the edge within the subgraph from the _Arg node to the image of + // the dst node. + Status RecordArg(const Edge* edge, + const std::unordered_map& node_images, + std::vector>* src_arg_pairs); + + // Creates a _Retval node for the src node of edge, and add it to results_, + // if none exists yet. If a new _Retval node is created, also adds the edge + // within the subgraph from the src to the _Retval node. + Status RecordResult( + const Edge* edge, + const std::unordered_map& node_images); + + private: + // Builds a ParallelCheck op that compares the output of the original + // subgraph with the encapsulated subgraph. + Status BuildParallelCheckOp( + const std::unordered_map& node_images, + Graph* graph_out); + // The subgraph extracted from the input graph, suitable for being turned // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are // returned by _Retval nodes. - std::unique_ptr graph; + std::unique_ptr graph_; // Which device are these nodes on? Used to assign a device to the call // node. - string device; + string device_; // NodeDef for the function call node. - NodeDef call_node_def; + NodeDef call_node_def_; // Function call node(s) in the output graph. Not owned. // If parallel_checking is enabled, 'call_node_inputs' is the function call // node to which inputs should be fed, and 'call_node_outputs' is the // parallel check op from which outputs should be read. If parallel checking // is disabled, both point to the function call node. - Node* call_node_inputs; - Node* call_node_outputs; + Node* call_node_inputs_; + Node* call_node_outputs_; // Maps from source (producer node/slot) and destination // (consumer node/slot) tensors in the input graph to _Arg numbers in // the subgraph. The source map is one-to-one, whereas the dest map may be // many-to-one. - std::unordered_map args_by_src; - std::unordered_map args_by_dst; + std::unordered_map args_by_src_; + std::unordered_map args_by_dst_; // The _Arg nodes in the subgraph, in order by argument number. - std::vector args; + std::vector args_; // Map from source tensor in the input graph to result #. - std::unordered_map results; + std::unordered_map results_; }; - // Builds a ParallelCheck op that compares the output of the original subgraph - // with the encapsulated subgraph. - Status BuildParallelCheckOp( + // Returns the key attribute associated with a node in attr. Sets attr to the + // empty string if the attribute is not found. + Status GetFunctionNameAttr(const Node* node, string* attr) const; + + // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to subgraphs + // for data edges that cross subgraph boundaries. + Status CopySubgraphEdges( + const std::unordered_map& node_images, + std::vector>* src_arg_pairs); + + // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes. + Status CopySubgraphNodes(std::unordered_map* node_images); + + // Copies all nodes that aren't in a compiled subgraph to the output graph. + Status CopyNodesToOutputGraph( + bool parallel_checking, Graph* graph_out, + std::unordered_map* node_images); + + // Adds function call nodes for each compiled subgraph. + Status AddFunctionCallNodes( const std::unordered_map& node_images, - const Subgraph& subgraph, Graph* graph_out, Node** parallel_check_op); + bool parallel_checking, Graph* graph_out); + + // Finds the image of an edge source in the output graph. If the edge crosses + // a subgraph boundary it is the output of a call node, otherwise it is a node + // in the output graph. + Status FindOutputImageOfEdgeSrc( + const string& src_func_id, const string& dst_func_id, + const std::unordered_map& node_images, + const Node* original_src_node, Node** src_image); + + // Finds an edge source slot in the output graph. If the edge crosses a + // subgraph boundary it is a slot on the output of a call node, otherwise it + // is a slot on a node in the output graph. + int FindOutputSlotOfEdgeSrc(const string& src_func_id, + const string& dst_func_id, const Edge* edge); + + // Finds the image of an edge destination in the output graph. If the edge + // crosses a subgraph boundary it is the input of a call node, otherwise it is + // a node in the output graph. + Status FindOutputImageOfEdgeDst( + const string& src_func_id, const string& dst_func_id, + const std::unordered_map& node_images, + const Node* original_dst_node, Node** dst_image); + + // Finds an edge destination slot in the output graph. If the edge crosses a + // subgraph boundary it is a slot on the input of a call node, otherwise it is + // a slot on a node in the output graph. + int FindOutputSlotOfEdgeDst(const string& src_func_id, + const string& dst_func_id, const Edge* edge); + + // Copies a single edge to the output graph. The edge is either entirely + // within the output graph, or crosses into or out of a compiled subgraph. + Status CopyEdgeToOutputGraph( + const Edge* edge, const string& src_func_id, const string& dst_func_id, + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out, + std::unordered_set, NodeSlot::PairHasher>* + edges_added); + + // Adds all edges to the output graph. + Status AddEdgesToOutputGraph( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out); const string group_attribute_; + const string outside_compilation_attribute_; const Graph* graph_in_; std::unordered_map subgraphs_; @@ -154,224 +323,184 @@ class Encapsulator { TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); }; -// TODO(phawkins) add a canonical copy of these operator names and refactor -// everything to use it. -static const char* const kArgOp = "_Arg"; -static const char* const kRetValOp = "_Retval"; - -// Returns the function name attached to 'node', or the empty string if there is -// none. -string Encapsulator::GetFunctionNameAttr(Node const* node) const { - string attr; - if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) { - attr.clear(); - } - return attr; +Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { + return call_node_inputs_; } -Status Encapsulator::SplitIntoSubgraphs() { - Status s; - - // Map from input graph nodes to subgraph nodes. - std::unordered_map node_images; - - // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes. - for (Node* node : graph_in_->op_nodes()) { - string func_id = GetFunctionNameAttr(node); - if (func_id.empty()) continue; +Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { + return call_node_outputs_; +} - Subgraph& subgraph = subgraphs_[func_id]; - if (!subgraph.graph) { - subgraph.graph.reset(new Graph(graph_in_->op_registry())); - subgraph.graph->set_versions(graph_in_->versions()); - } +int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { + return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); +} - Node* image = subgraph.graph->CopyNode(node); - image->ClearAttr(group_attribute_); - node_images[node] = image; +int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { + return results_.at(NodeSlot(edge->src(), edge->src_output())); +} - if (subgraph.device.empty()) { - subgraph.device = node->assigned_device_name().empty() - ? node->requested_device() - : node->assigned_device_name(); - } +Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { + if (!graph_) { + graph_.reset(new Graph(graph_in->op_registry())); + graph_->set_versions(graph_in->versions()); } - // Copy edges local to a subgraph. Add _Arg and _Retval nodes to subgraphs for - // data edges that cross subgraph boundaries. - for (const Edge* edge : graph_in_->edges()) { - string src_func_id = GetFunctionNameAttr(edge->src()); - string dst_func_id = GetFunctionNameAttr(edge->dst()); - Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); - Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); - - // Copy edges that are local to a subgraph. - if (!src_func_id.empty() && src_func_id == dst_func_id) { - Graph* g = subgraphs_[src_func_id].graph.get(); - if (edge->IsControlEdge()) { - g->AddControlEdge(src_image, dst_image); - } else { - g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); - } - continue; - } - - // Ignore cross-boundary control edges for right now. We will lift them - // onto the enclosing call operators in BuildOutputGraph(). - if (edge->IsControlEdge()) continue; + if (device_.empty()) { + device_ = node->assigned_device_name().empty() + ? node->requested_device() + : node->assigned_device_name(); + } - // Add 'src' as an output of its subgraph, if applicable. - if (!src_func_id.empty()) { - Subgraph& src_subgraph = subgraphs_[src_func_id]; - int ret_index = src_subgraph.results.size(); - if (src_subgraph.results - .emplace(NodeSlot(edge->src(), edge->src_output()), ret_index) - .second) { - // Create a new _Retval node - DataType dtype = edge->src()->output_type(edge->src_output()); - - if (IsRefType(dtype)) { - return errors::InvalidArgument( - "Ref Tensors (e.g., Variables) are not supported: tensor ", - edge->src()->name(), ":", edge->src_output()); - } + return graph_->CopyNode(node); +} - NodeDef ret_def; - ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(edge->src()->name(), "_", - edge->src_output(), "_retval")); - AddNodeAttr("T", dtype, &ret_def); - AddNodeAttr("index", ret_index, &ret_def); - Node* ret = src_subgraph.graph->AddNode(ret_def, &s); - if (!s.ok()) return s; - - // Add an edge from 'src' to _Retval. - src_subgraph.graph->AddEdge(src_image, edge->src_output(), ret, 0); - } +Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } + +Status Encapsulator::Subgraph::RecordArg( + const Edge* edge, const std::unordered_map& node_images, + std::vector>* src_arg_pairs) { + Node* src_node = edge->src(); + int src_slot = edge->src_output(); + std::unordered_map::iterator iter; + bool inserted; + std::tie(iter, inserted) = + args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); + int arg_index = iter->second; + if (inserted) { + // Look at the type of the destination not the source, since Ref output + // Tensors can be automatically cast to non-Ref Tensors at the destination. + DataType dtype = edge->dst()->input_type(edge->dst_input()); + + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported as args: tensor ", + src_node->name(), ":", src_slot); } - // Add 'dst' as an input of its subgraph, if applicable. - if (!dst_func_id.empty()) { - Subgraph& dst_subgraph = subgraphs_[dst_func_id]; - - // Create an _Arg node for this tensor, if none exists yet. - std::unordered_map::iterator iter; - bool inserted; - std::tie(iter, inserted) = dst_subgraph.args_by_src.emplace( - NodeSlot(edge->src(), edge->src_output()), dst_subgraph.args.size()); - int arg_index = iter->second; - if (inserted) { - // This is the first time we have seen this tensor. Create an _Arg node. - DataType dtype = edge->dst()->input_type(edge->dst_input()); - - if (IsRefType(dtype)) { - return errors::InvalidArgument( - "Ref Tensors (e.g., Variables) are not supported: tensor ", - edge->src()->name(), ":", edge->src_output()); - } + NodeDef arg_def; + NodeDefBuilder builder( + strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + builder.Attr("T", dtype); + builder.Attr("index", arg_index); + Status s = builder.Finalize(&arg_def); + if (!s.ok()) return s; - NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(edge->src()->name(), "_", - edge->src_output(), "_arg"), - kArgOp); - builder.Attr("T", dtype); - builder.Attr("index", arg_index); - s = builder.Finalize(&arg_def); - if (!s.ok()) return s; + Node* arg = graph_->AddNode(arg_def, &s); + if (!s.ok()) return s; - Node* arg = dst_subgraph.graph->AddNode(arg_def, &s); - if (!s.ok()) return s; + src_arg_pairs->push_back({src_node, arg}); + args_.push_back(arg); + } + Node* dst_node = edge->dst(); + Node* dst_image = node_images.at(dst_node); + int dst_slot = edge->dst_input(); + args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index; + graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); + return Status::OK(); +} - dst_subgraph.args.push_back(arg); - } - // Add an edge from the _Arg node to 'dst' in the subgraph. - dst_subgraph.args_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = - arg_index; - dst_subgraph.graph->AddEdge(dst_subgraph.args[arg_index], 0, dst_image, - edge->dst_input()); +Status Encapsulator::Subgraph::RecordResult( + const Edge* edge, + const std::unordered_map& node_images) { + Node* src_node = edge->src(); + Node* src_image = node_images.at(src_node); + int src_slot = edge->src_output(); + std::unordered_map::iterator iter; + bool inserted; + std::tie(iter, inserted) = + results_.emplace(NodeSlot(src_node, src_slot), results_.size()); + int ret_index = iter->second; + if (inserted) { + DataType dtype = src_node->output_type(src_slot); + + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported as results: tensor ", + src_node->name(), ":", src_slot); } - } - for (auto& entry : subgraphs_) { - FixupSourceAndSinkEdges(entry.second.graph.get()); - } + NodeDef ret_def; + NodeDefBuilder builder( + strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + builder.Attr("T", dtype); + builder.Attr("index", ret_index); + builder.Input(src_image->name(), src_slot, dtype); + Status s = builder.Finalize(&ret_def); + if (!s.ok()) return s; + Node* ret = graph_->AddNode(ret_def, &s); + if (!s.ok()) return s; - return s; + graph_->AddEdge(src_image, src_slot, ret, 0); + } + return Status::OK(); } -Status Encapsulator::BuildFunctionDefs( - const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, - FunctionLibraryDefinition* library) { - // For each subgraph, build a FunctionDef. - for (auto& subgraph_entry : subgraphs_) { - string name = subgraph_entry.first; - Subgraph& subgraph = subgraph_entry.second; - - subgraph.call_node_def.set_op(name); - subgraph.call_node_def.set_name(name); - subgraph.call_node_def.set_device(subgraph.device); - - if (rewrite_subgraph_fn) { - // Initialize the input and output permutations to the identity. - std::vector input_permutation(subgraph.args_by_src.size()); - std::iota(input_permutation.begin(), input_permutation.end(), 0); - std::vector output_permutation(subgraph.results.size()); - std::iota(output_permutation.begin(), output_permutation.end(), 0); - - TF_RETURN_IF_ERROR( - rewrite_subgraph_fn(&subgraph.graph, &input_permutation, - &output_permutation, &subgraph.call_node_def)); - - // Apply the input/output permutations to the 'args_by_...' and 'results' - // mappings in 'subgraph', so when we build edges in BuildOutputGraph() we - // connect them to the right input/output positions. - if (input_permutation.size() != subgraph.args_by_src.size()) { - return errors::InvalidArgument("Input permutation has incorrect size."); - } - if (output_permutation.size() != subgraph.results.size()) { - return errors::InvalidArgument( - "Output permutation has incorrect size."); - } - for (auto& arg : subgraph.args_by_src) { - arg.second = input_permutation[arg.second]; - } - for (auto& arg : subgraph.args_by_dst) { - arg.second = input_permutation[arg.second]; - } - for (auto& result : subgraph.results) { - result.second = output_permutation[result.second]; - } - - name = subgraph.call_node_def.op(); +Status Encapsulator::Subgraph::BuildFunctionDef( + const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, FunctionLibraryDefinition* library) { + // name_in is copied here because name may be modified below if + // rewrite_subgraph_fn is true. + string name = name_in; + call_node_def_.set_op(name); + call_node_def_.set_name(name); + call_node_def_.set_device(device_); + + if (rewrite_subgraph_fn) { + // Initialize the input and output permutations to the identity. + std::vector input_permutation(args_by_src_.size()); + std::iota(input_permutation.begin(), input_permutation.end(), 0); + std::vector output_permutation(results_.size()); + std::iota(output_permutation.begin(), output_permutation.end(), 0); + + TF_RETURN_IF_ERROR(rewrite_subgraph_fn( + &graph_, &input_permutation, &output_permutation, &call_node_def_)); + + // Apply the input/output permutations to the 'args_by_...' and 'results_' + // mappings, so when we build edges in BuildOutputGraph() we + // connect them to the right input/output positions. + if (input_permutation.size() != args_by_src_.size()) { + return errors::InvalidArgument("Input permutation has incorrect size."); + } + if (output_permutation.size() != results_.size()) { + return errors::InvalidArgument("Output permutation has incorrect size."); + } + for (auto& arg : args_by_src_) { + arg.second = input_permutation[arg.second]; + } + for (auto& arg : args_by_dst_) { + arg.second = input_permutation[arg.second]; + } + for (auto& result : results_) { + result.second = output_permutation[result.second]; } - FunctionDef fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph.graph, name, &fdef)); + name = call_node_def_.op(); + } - if (VLOG_IS_ON(1)) { - VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_fdef_graph_", name), *subgraph.graph, - library); - dump_graph::DumpFunctionDefToFile( - strings::StrCat("encapsulate_fdef_", name), fdef); - } + FunctionDef fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); - if (!reuse_existing_functions || library->Find(name) == nullptr) { - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); - } + if (VLOG_IS_ON(1)) { + VLOG(2) << "Build function def " << name; + dump_graph::DumpGraphToFile( + strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); + dump_graph::DumpFunctionDefToFile( + strings::StrCat("encapsulate_fdef_", name), fdef); + } + + if (!reuse_existing_functions || library->Find(name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); } return Status::OK(); } -Status Encapsulator::BuildParallelCheckOp( +Status Encapsulator::Subgraph::BuildParallelCheckOp( const std::unordered_map& node_images, - const Encapsulator::Subgraph& subgraph, Graph* graph_out, - Node** parallel_check_op) { + Graph* graph_out) { // Build an index mapping output positions to node/slot pairs in the // original graph. - std::vector results_by_num(subgraph.results.size()); - for (const auto& entry : subgraph.results) { + std::vector results_by_num(results_.size()); + for (const auto& entry : results_) { results_by_num[entry.second] = entry.first; } @@ -386,22 +515,22 @@ Status Encapsulator::BuildParallelCheckOp( expected_outputs[i] = NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), node_slot.slot, result_dtypes[i]); - actual_outputs[i] = NodeDefBuilder::NodeOut(subgraph.call_node_def.name(), - i, result_dtypes[i]); + actual_outputs[i] = + NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); } // Assign the parallel check op to a CPU on the same task as the cluster it is // checking. string device, dummy; if (!DeviceNameUtils::SplitDeviceName( - subgraph.call_node_inputs->assigned_device_name(), &device, &dummy)) { + call_node_inputs_->assigned_device_name(), &device, &dummy)) { return errors::InvalidArgument("Could not parse device name"); } strings::StrAppend(&device, "/cpu:0"); NodeDef check_def; TF_RETURN_IF_ERROR( - NodeDefBuilder(graph_out->NewName(strings::StrCat( - subgraph.call_node_def.name(), "_parallel_check")), + NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), + "_parallel_check")), "ParallelCheck") .Device(device) .Attr("T", result_dtypes) @@ -421,65 +550,303 @@ Status Encapsulator::BuildParallelCheckOp( const NodeSlot& node_slot = results_by_num[i]; graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, i); - graph_out->AddEdge(subgraph.call_node_inputs, i, check_op, num_results + i); + graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); } - *parallel_check_op = check_op; + call_node_outputs_ = check_op; return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, - Graph* graph_out) { +Status Encapsulator::Subgraph::AddFunctionCallNode( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { Status s; + call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); + if (!s.ok()) return s; - // Map from nodes in the input graph to nodes in the output graph. + // Copy the assigned device and the key_annotation over. + call_node_inputs_->set_assigned_device_name(device_); + call_node_outputs_ = call_node_inputs_; + + if (parallel_checking) { + TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); + } + return Status::OK(); +} + +Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const { + Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); + if (s.code() == error::Code::NOT_FOUND) { + // Return empty attr if there's no group_attribute. + attr->clear(); + return Status::OK(); + } + return s; +} + +bool IsInSubgraph(const string& func_id) { return !func_id.empty(); } + +Status Encapsulator::CopySubgraphNodes( + std::unordered_map* node_images) { + for (Node* node : graph_in_->op_nodes()) { + string func_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); + if (!IsInSubgraph(func_id)) continue; + + Subgraph& subgraph = subgraphs_[func_id]; + Node* image = subgraph.MakeNodeImage(graph_in_, node); + image->ClearAttr(group_attribute_); + (*node_images)[node] = image; + } + return Status::OK(); +} + +Status Encapsulator::CopySubgraphEdges( + const std::unordered_map& node_images, + std::vector>* src_arg_pairs) { + for (const Edge* edge : graph_in_->edges()) { + string src_func_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); + string dst_func_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); + Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); + Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); + + // Copy edges that are local to a subgraph. + if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) && + src_func_id == dst_func_id) { + Graph* g = subgraphs_[src_func_id].GetGraph(); + if (edge->IsControlEdge()) { + g->AddControlEdge(src_image, dst_image); + } else { + g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); + } + continue; + } + + // Record 'src' as an output of its subgraph, if applicable. + if (IsInSubgraph(src_func_id)) { + Subgraph& src_subgraph = subgraphs_[src_func_id]; + // Ignore control edges leaving the subgraph. We will lift them onto the + // enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); + } + } + + // Record 'dst' as an input of its subgraph, if applicable. + if (IsInSubgraph(dst_func_id)) { + Subgraph& dst_subgraph = subgraphs_[dst_func_id]; + // Ignore control edges entering the subgraph. We will lift them onto + // the enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR( + dst_subgraph.RecordArg(edge, node_images, src_arg_pairs)); + } + } + } + return Status::OK(); +} + +Status Encapsulator::SplitIntoSubgraphs() { + Status s; + + // Map from input graph nodes to subgraph nodes. std::unordered_map node_images; - // Copy all unmarked nodes to the output graph. + // Each entry of src_arg_pairs is a pair whose first element is a node in the + // original graph that has an output edge in the subgraph, and whose second + // element is the arg node in the subgraph that it sends to. The vector will + // be filled in below in AddArgs. + std::vector> src_arg_pairs; + + TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images)); + TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs)); + + MarkGuaranteedConstants(*graph_in_, src_arg_pairs); + + for (auto& entry : subgraphs_) { + Subgraph& subgraph = entry.second; + FixupSourceAndSinkEdges(subgraph.GetGraph()); + } + + return s; +} + +Status Encapsulator::BuildFunctionDefs( + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, + FunctionLibraryDefinition* library) { + for (auto& subgraph_entry : subgraphs_) { + string name = subgraph_entry.first; + Subgraph& subgraph = subgraph_entry.second; + TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef( + name, rewrite_subgraph_fn, reuse_existing_functions, library)); + } + return Status::OK(); +} + +Status Encapsulator::CopyNodesToOutputGraph( + bool parallel_checking, Graph* graph_out, + std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { - string func_id = GetFunctionNameAttr(node); + string func_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); - // Don't copy nodes that going to be encapsulated, unless parallel checking - // is enabled. - if (!func_id.empty() && !parallel_checking) continue; + // Don't copy nodes that are going to be encapsulated, unless parallel + // checking is enabled. + if (IsInSubgraph(func_id) && !parallel_checking) continue; Node* image = graph_out->CopyNode(node); - node_images[node] = image; + (*node_images)[node] = image; } - node_images[graph_in_->source_node()] = graph_out->source_node(); - node_images[graph_in_->sink_node()] = graph_out->sink_node(); + (*node_images)[graph_in_->source_node()] = graph_out->source_node(); + (*node_images)[graph_in_->sink_node()] = graph_out->sink_node(); + return Status::OK(); +} - // Add function call nodes for each subgraph. +Status Encapsulator::AddFunctionCallNodes( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { - Subgraph& subgraph = subgraph_entry.second; + TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( + node_images, parallel_checking, graph_out)); + } + return Status::OK(); +} - subgraph.call_node_inputs = graph_out->AddNode(subgraph.call_node_def, &s); - if (!s.ok()) return s; +Status Encapsulator::FindOutputImageOfEdgeSrc( + const string& src_func_id, const string& dst_func_id, + const std::unordered_map& node_images, + const Node* original_src_node, Node** src_image) { + if (IsInSubgraph(src_func_id)) { + // The edge is from a subgraph to a regular node in the output graph so + // use the subgraph's call node output. + *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + } else { + // The source of the edge is in the output graph so use the node image in + // the output graph. + *src_image = node_images.at(original_src_node); + } + return Status::OK(); +} + +int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, + const string& dst_func_id, + const Edge* edge) { + if (IsInSubgraph(src_func_id)) { + const Subgraph& src_subgraph = subgraphs_.at(src_func_id); + // 'src' is in a subgraph and 'dst' is a regular node in the output + // graph. Use the corresponding call output instead. + return src_subgraph.GetResultIndexForEdge(edge); + } else { + // The source of the edge is in the output graph so use the regular edge + // slot. + return edge->src_output(); + } +} - // Copy the assigned device and the key_annotation over. - subgraph.call_node_inputs->set_assigned_device_name(subgraph.device); - subgraph.call_node_outputs = subgraph.call_node_inputs; +Status Encapsulator::FindOutputImageOfEdgeDst( + const string& src_func_id, const string& dst_func_id, + const std::unordered_map& node_images, + const Node* original_dst_node, Node** dst_image) { + if (IsInSubgraph(dst_func_id)) { + // The edge is to a subgraph from a regular node in the output graph so + // use the subgraph's call node input. + *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + } else { + // The destination of the edge is in the output graph so use the node image + // in the output graph. + *dst_image = node_images.at(original_dst_node); + } + return Status::OK(); +} +int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, + const string& dst_func_id, + const Edge* edge) { + if (IsInSubgraph(dst_func_id)) { + const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); + // 'dst' is in a subgraph and 'src' is a regular node in the output + // graph. Use the corresponding call input instead. + return dst_subgraph.GetArgIndexForEdge(edge); + } else { + // The destination of the edge is in the output graph so use the regular + // edge slot. + return edge->dst_input(); + } +} + +Status Encapsulator::CopyEdgeToOutputGraph( + const Edge* edge, const string& src_func_id, const string& dst_func_id, + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out, + std::unordered_set, NodeSlot::PairHasher>* + edges_added) { + Node* src_image; + TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( + src_func_id, dst_func_id, node_images, edge->src(), &src_image)); + Node* dst_image; + TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst( + src_func_id, dst_func_id, node_images, edge->dst(), &dst_image)); + + // If this is a control edge then copy it and return. Lift control edges onto + // the enclosing call operator. + if (edge->IsControlEdge()) { + // Add the control edge, if we have not already added it, using the images + // determined above (potentially call operators or RecvAtHost/SendFromHost). + if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) + .second) { + graph_out->AddControlEdge(src_image, dst_image); + } + + // If parallel checking is enabled, also add a control edge to the + // corresponding parallel check op. if (parallel_checking) { - TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, subgraph, graph_out, - &subgraph.call_node_outputs)); + graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); } + return Status::OK(); + } + + int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge); + + int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge); + + if (IsInSubgraph(dst_func_id) && parallel_checking) { + // If we are parallel checking, also feed the tensor as an input to the + // corresponding parallel check subgraph. + graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), + edge->dst_input()); } + // Add the edge, if we have not already added it. + if (edges_added + ->emplace(NodeSlot(src_image, src_output), + NodeSlot(dst_image, dst_input)) + .second) { + graph_out->AddEdge(src_image, src_output, dst_image, dst_input); + } + return Status::OK(); +} + +Status Encapsulator::AddEdgesToOutputGraph( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. std::unordered_set, NodeSlot::PairHasher> edges_added; - // Add edges to the graph_out graph. for (const Edge* edge : graph_in_->edges()) { - string src_func_id = GetFunctionNameAttr(edge->src()); - string dst_func_id = GetFunctionNameAttr(edge->dst()); + string src_func_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); + string dst_func_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); // Ignore edges that are strictly contained within one subgraph, unless // we are constructing parallel check graphs. - if (!src_func_id.empty() && src_func_id == dst_func_id) { + if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) && + src_func_id == dst_func_id) { if (parallel_checking) { Node* src_image = node_images.at(edge->src()); Node* dst_image = node_images.at(edge->dst()); @@ -493,63 +860,29 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, continue; } - // We have an edge that crosses a cluster boundary. - Node* src_image = src_func_id.empty() - ? node_images.at(edge->src()) - : subgraphs_.at(src_func_id).call_node_outputs; - Node* dst_image = dst_func_id.empty() - ? node_images.at(edge->dst()) - : subgraphs_.at(dst_func_id).call_node_inputs; - - // Copy control edges. Lift control edges onto the enclosing call operator. - if (edge->IsControlEdge()) { - // Add the control edge, if we have not already added it. - if (edges_added.emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) - .second) { - graph_out->AddControlEdge(src_image, dst_image); - } - - // If parallel checking is enabled, also add a control edge to the - // corresponding parallel check op. - if (parallel_checking) { - graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); - } - continue; - } - - int src_output = edge->src_output(); - if (!src_func_id.empty()) { - // 'src' is in a subgraph. Use the corresponding call output instead. - const Subgraph& src_subgraph = subgraphs_.at(src_func_id); - src_output = - src_subgraph.results.at(NodeSlot(edge->src(), edge->src_output())); - } + // We have an edge that crosses a cluster boundary or is entirely within the + // unclustered graph. + TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(edge, src_func_id, dst_func_id, + node_images, parallel_checking, + graph_out, &edges_added)); + } - int dst_input = edge->dst_input(); + return Status::OK(); +} - if (!dst_func_id.empty()) { - // 'dst' is in a subgraph. Use the corresponding call input instead. - const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); - dst_input = - dst_subgraph.args_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); +Status Encapsulator::BuildOutputGraph(bool parallel_checking, + Graph* graph_out) { + // Map from nodes in the input graph to nodes in the output graph. + std::unordered_map node_images; - // If we are parallel checking, also feed the tensor as an input to the - // corresponding parallel check subgraph. - if (parallel_checking) { - graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), - edge->dst_input()); - } - } - // Add the edge, if we have not already added it. - if (edges_added - .emplace(NodeSlot(src_image, src_output), - NodeSlot(dst_image, dst_input)) - .second) { - graph_out->AddEdge(src_image, src_output, dst_image, dst_input); - } - } + TF_RETURN_IF_ERROR( + CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); + TF_RETURN_IF_ERROR( + AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); - return s; + return Status::OK(); } } // anonymous namespace @@ -562,20 +895,18 @@ Status EncapsulateSubgraphsInFunctions( Status s; Encapsulator encapsulator(std::move(group_attribute), &graph_in); - s = encapsulator.SplitIntoSubgraphs(); - if (!s.ok()) return s; + TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); - s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, - reuse_existing_functions, library); - if (!s.ok()) return s; + TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( + rewrite_subgraph_fn, reuse_existing_functions, library)); std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); - s = encapsulator.BuildOutputGraph(parallel_checking, out.get()); - if (!s.ok()) return s; + TF_RETURN_IF_ERROR( + encapsulator.BuildOutputGraph(parallel_checking, out.get())); *graph_out = std::move(out); - return s; + return Status::OK(); } // Finds the types of the _Arg nodes, indexed by position. @@ -691,8 +1022,8 @@ Status EncapsulateSubgraphsPass::Run( TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( kXlaClusterAttr, **options.graph, rewrite_subgraph, - flags->tf_xla_parallel_checking, /*reuse_existing_functions=*/false, - &graph_out, library)); + flags->tf_xla_parallel_checking, + /*reuse_existing_functions=*/false, &graph_out, library)); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 4a1dbaf05dc7824835f3567c6abcf48222720230..717efb360185f1ce26ee1e9adb0ee5bf7f4799f8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -398,5 +398,109 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : graph.nodes()) { + if (node->name() == name) return node; + } + return nullptr; +} + +bool HasGuaranteeConstAttr(const Node& n) { + bool is_guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", + &is_guaranteed_constant) + .ok()) { + return false; + } + return is_guaranteed_constant; +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2); + add1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + EXPECT_EQ(2, guaranteed_consts); +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto const_guarantee_x2 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2); + auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"), + const_guarantee_x1, const_guarantee_x2); + auto add2 = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2); + auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2); + mul1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const + // and another non-const, so overall non-const. + EXPECT_EQ(1, guaranteed_consts); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 459a582e157f5ddc63997ca93e7c0294293517d3..9bea5663319c8a25249fdc265cee0191556a7c04 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -16,7 +16,6 @@ cc_library( "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index e481796d9e626fc8cdf36687ad110b0a8a788be0..4f3f17df9c680c63546d17dcc5a2775a1014f6c3 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -103,7 +102,6 @@ xla::StatusOr XlaAllocator::Allocate( } void* data = reinterpret_cast(const_cast(t.tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = t; return gpu::DeviceMemoryBase(data, size); } @@ -111,7 +109,6 @@ xla::StatusOr XlaAllocator::Allocate( Status XlaAllocator::RegisterArgument(const Tensor* t) { void* data = reinterpret_cast(const_cast(t->tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = *t; return Status::OK(); } @@ -267,7 +264,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Builds an XLA allocator for the device. XlaAllocator xla_allocator(client->platform(), ctx); - XlaLocalRuntimeContext local_runtime_context; std::unique_ptr output; // Build xla::ShapedBuffers that point directly to the Tensor buffers. @@ -291,27 +287,22 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( const_cast(t->tensor_data().data()), t->tensor_data().size()); - arg_buffers[i] = - xla::ShapedBuffer::MakeArrayShapedBuffer( - shape, client->platform(), client->default_device_ordinal(), dmem) - .ConsumeValueOrDie(); + const xla::Shape on_device_shape = + client->backend().transfer_manager()->HostShapeToDeviceShape(shape); + CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + << "On-device shape " + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) + << " not the same as on-host shape " + << xla::ShapeUtil::HumanStringWithLayout(shape); + arg_buffers[i] = xla::MakeUnique( + /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(), + client->default_device_ordinal()); + arg_buffers[i]->set_buffer(dmem, /*index=*/{}); arg_ptrs[i] = arg_buffers[i].get(); OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); } - // Make the final parameter point at local_runtime_context. - if (kernel->requires_runtime_context) { - gpu::DeviceMemoryBase local_runtime_context_dmem( - &local_runtime_context, sizeof(local_runtime_context)); - arg_buffers.push_back( - xla::ShapedBuffer::MakeArrayShapedBuffer( - xla::ShapeUtil::MakeOpaqueShape(), client->platform(), - client->default_device_ordinal(), local_runtime_context_dmem) - .ConsumeValueOrDie()); - arg_ptrs.push_back(arg_buffers.back().get()); - } - // Execute the computation. VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; @@ -323,19 +314,13 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { auto run_result = executable->Run(arg_ptrs, run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - if (local_runtime_context.error) { - ctx->CtxFailure(errors::InvalidArgument("Compiled kernel returned error: ", - local_runtime_context.error_msg)); - return; - } - output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; // Computation output should always be a tuple. if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->shape().DebugString(); + VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 74c9791f5eaf1fbc43b152520df496a3b552af18..1f311a3aedbf7711ce6a081671f5848a81f2bd85 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -172,10 +172,15 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } +struct NodeCompare { + bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); } +}; +using OrderedNodeSet = std::set; + Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, - std::unordered_set* candidates) { + OrderedNodeSet* candidates) { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -210,6 +215,13 @@ Status FindCompilationCandidates( !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { continue; } + // _Retval nodes in a top-level function represent fetches. + // Do not compile them. + if (node->type_string() == "_Retval") { + VLOG(2) << "Compilation rejected node: return value " << node->name() + << ": " << node->type_string(); + continue; + } candidates->insert(node); } return Status::OK(); @@ -347,7 +359,7 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); - std::unordered_set compilation_candidates; + OrderedNodeSet compilation_candidates; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index b3d258aea177fbefa4bae51d8156da2ff86c9032..454f0aeae98d7afd51f12b2cfb1810de275a57f7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -525,5 +525,32 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { "+-- c\n")); } +TEST(XlaCompilationTest, Retval) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::UnaryOp("_Retval", b, + builder.opts() + .WithName("R") + .WithAttr("T", DT_FLOAT) + .WithAttr("index", 0)); + + TF_EXPECT_OK(builder.ToGraph(graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(2, clusters.size()); + EXPECT_TRUE(clusters.find("R") == clusters.cend()); + EXPECT_EQ(clusters["A"], clusters["B"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index bc2eccd2779b9ff68ae2121f7bc53d6f74aec3e3..3717c2cc24283e0b218f92ec820d16893cbe0c35 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -214,17 +214,12 @@ Status XlaCompilationCache::BuildExecutable( const XlaCompiler::CompilationResult& result, std::unique_ptr* executable) { VLOG(2) << "Compiling to local executable"; - xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); std::vector argument_layouts( result.xla_input_shapes.size()); for (int i = 0; i < result.xla_input_shapes.size(); ++i) { argument_layouts[i] = &result.xla_input_shapes[i]; } - if (result.requires_runtime_context) { - // The final arg is the XlaLocalRuntimeContext*. - argument_layouts.push_back(&opaque_shape); - } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index fed2c92d763c33aad3c5b3f07c1f33364c797793..c936222f32056e92efced82d5adb3a96c8041a17 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -71,12 +71,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, void* dst_ptr = DMAHelper::base(device_tensor); se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes); - Status status = Status::OK(); + Status status; stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. - if (!stream_->BlockHostUntilDone()) { + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p", stream_); + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); } done(status); @@ -105,12 +107,14 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, se::DeviceMemoryBase dev_src_ptr(src_ptr, total_bytes); void* dst_ptr = DMAHelper::base(cpu_tensor); - Status status = Status::OK(); + Status status; stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. - if (!stream_->BlockHostUntilDone()) { + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p", stream_); + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); } done(status); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 79c4befd3671e1da3fd67e644eb733d2503f9a8b..4f458ecff8f6523a23ca59e0cecb485a7988efad 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -279,6 +279,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "image_ops_test", + size = "small", + srcs = ["image_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:image_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -367,7 +380,15 @@ tf_xla_py_test( size = "small", srcs = ["random_ops_test.py"], # TODO(b/31361304): enable RNG ops on GPU when parallelized. - disabled_backends = ["gpu"], + disabled_backends = [ + "gpu", + "cpu", + ], + tags = [ + "manual", + "no_oss", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:framework_for_generated_wrappers", @@ -416,6 +437,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "scan_ops_test", + size = "small", + srcs = ["scan_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "segment_reduction_ops_test", size = "medium", @@ -457,6 +492,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "stateless_random_ops_test", + size = "small", + srcs = ["stateless_random_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/contrib/stateless", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "tensor_array_ops_test", size = "small", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 654dc15e86b21c7742d49281d53c1a75e6a45d3b..65706b35d616eb4dce94f0a7056a1604a97ff4c1 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -94,14 +94,12 @@ class BinaryOpsTest(XLATestCase): dtype(4), expected=np.array([[16], [81]], dtype=dtype)) - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._testBinary( - math_ops.atan2, - np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), - np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), - expected=np.array( - [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) + self._testBinary( + math_ops.atan2, + np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), + np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), + expected=np.array( + [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) self._testBinary( gen_math_ops._reciprocal_grad, @@ -388,30 +386,28 @@ class BinaryOpsTest(XLATestCase): ], dtype=dtype)) - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._testBinary( - math_ops.pow, - dtype(3 + 2j), - dtype(4 - 5j), - expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) - self._testBinary( # empty rhs - math_ops.pow, - np.array([1 + 2j, 2 - 3j], dtype=dtype), - np.zeros(shape=[0, 2], dtype=dtype), - expected=np.zeros(shape=[0, 2], dtype=dtype)) - self._testBinary( # to zero power - math_ops.pow, - np.array([1 + 2j, 2 - 3j], dtype=dtype), - np.zeros(shape=[1, 2], dtype=dtype), - expected=np.ones(shape=[1, 2], dtype=dtype)) - lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) - rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) - scalar = dtype(2 + 2j) - self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) - self._testBinary( - math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) - self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) + self._testBinary( + math_ops.pow, + dtype(3 + 2j), + dtype(4 - 5j), + expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) + self._testBinary( # empty rhs + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], dtype=dtype)) + self._testBinary( # to zero power + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[1, 2], dtype=dtype), + expected=np.ones(shape=[1, 2], dtype=dtype)) + lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) + rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) + scalar = dtype(2 + 2j) + self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) + self._testBinary( + math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) + self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) @@ -421,9 +417,8 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) - if atan2_supported: - self._testBinary( - gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) + self._testBinary( + gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) @@ -547,7 +542,7 @@ class BinaryOpsTest(XLATestCase): self._testDivision(dtype) def testFloatDivision(self): - for dtype in self.float_types + self.complex_types: + for dtype in self.float_types | self.complex_types: self._testDivision(dtype) def _testRemainder(self, dtype): diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 5e06f9a72401935b9681c35a164b51f50a8538ae..035cdea1786d39f3d21bb63be5c8ccffe1608bdf 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -35,6 +35,9 @@ from tensorflow.python.platform import googletest class CategoricalTest(XLATestCase): """Test cases for random-number generating operators.""" + def output_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + def _chi2(self, expected, actual): """Returns Chi2 GOF statistic.""" actual = np.asarray(actual) @@ -55,7 +58,8 @@ class CategoricalTest(XLATestCase): """ with self.test_session() as sess, self.test_scope(): random_seed.set_random_seed(1618) - op = random_ops.multinomial(logits, num_samples) + op = random_ops.multinomial(logits, num_samples, + output_dtype=dtypes.int32) d = sess.run(op) batch_size, num_classes = logits.shape @@ -73,11 +77,11 @@ class CategoricalTest(XLATestCase): return freqs_mat - def _testRngIsNotConstant(self, rng, dtype): + def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. with self.test_session() as sess: with self.test_scope(): - x = rng(dtype) + x = rng(dtype, output_dtype) # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. @@ -92,21 +96,25 @@ class CategoricalTest(XLATestCase): (not np.array_equal(y, w))) def testCategoricalIsNotConstant(self): - def rng(unused_dtype): - return random_ops.multinomial([[1., 1., 1.]], 10) + def rng(dtype, output_dtype): + return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10, + output_dtype=output_dtype) - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + dtype = np.float32 + for output_dtype in self.output_dtypes(): + self._testRngIsNotConstant(rng, dtype, output_dtype) def testCategoricalIsInRange(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session() as sess: - with self.test_scope(): - x = random_ops.multinomial( - array_ops.ones(shape=[1, 20], dtype=dtype), 1000) - y = sess.run(x) - self.assertTrue((y >= 0).sum() == 1000) - self.assertTrue((y < 20).sum() == 1000) + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.test_session() as sess: + with self.test_scope(): + x = random_ops.multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), 1000, + output_dtype=output_dtype) + y = sess.run(x) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) def testSamplingCorrectness(self): np.random.seed(1618) # Make it reproducible. diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 7e3871312c86530b6d3cb0bbacc16c25d3469832..f9db4cf2017c0b4b6dc0cfeeda6dca7bb9d14f19 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -161,9 +161,9 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose( + self.assertAllCloseAccordingToType( np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5) - self.assertAllClose( + self.assertAllCloseAccordingToType( np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) def testFtrlWithL1(self): @@ -189,10 +189,10 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-7.66718769, -10.91273689]), var0.eval(), - rtol=1e-4) - self.assertAllClose(np.array([-0.93460727, -1.86147261]), var1.eval(), - rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-7.66718769, -10.91273689]), var0.eval(), rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: @@ -217,10 +217,10 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval(), - rtol=1e-5) - self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval(), - rtol=1e-5) + self.assertAllCloseAccordingToType( + np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5) + self.assertAllCloseAccordingToType( + np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5) def testFtrlWithL1_L2_L2Shrinkage(self): """Test the new FTRL op with support for l2 shrinkage. @@ -244,18 +244,18 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) # Run 10 steps FTRL for _ in range(10): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-0.21931979, -0.40642974]), var0.eval(), - rtol=1e-4) - self.assertAllClose(np.array([-0.0282721, -0.07188385]), var1.eval(), - rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4) # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical @@ -272,8 +272,8 @@ class FtrlOptimizerTest(XLATestCase): with self.test_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) - self.assertAllClose(val0, val2, rtol=1e-4) - self.assertAllClose(val1, val3, rtol=1e-4) + self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4) + self.assertAllCloseAccordingToType(val1, val3, rtol=1e-4) def testEquivGradientDescentwithoutRegularization(self): steps = 5 @@ -284,8 +284,8 @@ class FtrlOptimizerTest(XLATestCase): val2, val3 = self.equivGradientDescentTest_GradientDescentPart( steps, dtype) - self.assertAllClose(val0, val2, rtol=1e-5) - self.assertAllClose(val1, val3, rtol=1e-5) + self.assertAllCloseAccordingToType(val0, val2, rtol=1e-5) + self.assertAllCloseAccordingToType(val1, val3, rtol=1e-5) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index cbe2888696c87c6c2f50c3de71e8531977ea395a..11d8a99ffe1a136a54b16e20f1792062203f7969 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -24,10 +24,12 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest +@test_util.with_c_api class FunctionTest(XLATestCase): def testFunction(self): diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index a773b5a94742062511bc8bdc6a202b513ce98db3..a80d69fa5f5099b8a8b67df0da9c92b957e9d194 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -76,7 +76,8 @@ class FusedBatchNormTest(XLATestCase): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") - offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") + offset = array_ops.placeholder( + np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format) @@ -112,7 +113,8 @@ class FusedBatchNormTest(XLATestCase): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") - offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") + offset = array_ops.placeholder( + np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y, mean, var = nn.fused_batch_norm( t_val, @@ -153,7 +155,7 @@ class FusedBatchNormTest(XLATestCase): def testLearningWithGradientChecker(self): self._testLearning(True) - def testGradient(self): + def testGradientTraining(self): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -173,7 +175,7 @@ class FusedBatchNormTest(XLATestCase): var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC") + grad, x, scale, mean, var, data_format="NHWC", is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { @@ -191,6 +193,53 @@ class FusedBatchNormTest(XLATestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + def testGradientInference(self): + # TODO(b/64270657): Use gradient_checker here in addition to comparing with + # this reference implementation. + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] + grad_val = np.random.random_sample(x_shape).astype(np.float32) + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + mean_val = np.random.random_sample(scale_shape).astype(np.float32) + var_val = np.random.random_sample(scale_shape).astype(np.float32) + + with self.test_session() as sess, self.test_scope(): + grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") + x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") + var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + with self.test_scope(): + out = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad_x, grad_scale, grad_offset, _, _ = out + + ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + + grad_x_val, grad_scale_val, grad_offset_val, = sess.run( + [grad_x, grad_scale, grad_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run( + [ref_x, ref_scale, ref_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + + self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) + self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) + self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a04f376ebf6092fd9b6e879796454b1a5c648c96 --- /dev/null +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for image ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_image_ops +from tensorflow.python.platform import test + + +class ResizeBilinearTest(XLATestCase): + + def _assertForwardOpMatchesExpected(self, + image_np, + target_shape, + expected=None): + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + image = array_ops.placeholder(image_np.dtype) + resized = gen_image_ops.resize_bilinear( + image, target_shape, align_corners=True) + out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def _assertBackwardOpMatchesExpected(self, + grads_np, + input_shape=None, + dtype=None, + expected=None): + if input_shape is None: + self.fail("input_shape must be specified") + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + dtype = dtype or np.float32 + grads = array_ops.placeholder(np.float32) + resized = gen_image_ops._resize_bilinear_grad( + grads, + np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), + align_corners=True) + out = sess.run(resized, {grads: grads_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def testAlignCorners1x2To3x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) + + def testAlignCorners1x2To3x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), + input_shape=[1, 2], + dtype=dtype, + expected=np.array([[9, 12]], dtype=np.float32)) + + def testAlignCorners2x2To1x1(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [1, 1], + expected=np.array([[1]], dtype=np.float32)) + + def testAlignCorners2x2To1x1Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7]], dtype=np.float32), + input_shape=[2, 2], + dtype=dtype, + expected=np.array([[7, 0], [0, 0]], dtype=np.float32)) + + def testAlignCorners2x2To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32)) + + def testAlignCorners2x2To3x3Grad(self): + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[2, 2], + expected=np.array([[5.25, 8.25], [14.25, 17.25]], dtype=np.float32)) + + def testAlignCorners3x3To2x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [2, 2], + expected=np.array([[1, 3], [7, 9]], dtype=np.float32)) + + def testAlignCorners3x3To2x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7, 13], [22, 4]], dtype=np.float32), + input_shape=[3, 3], + dtype=dtype, + expected=np.array( + [[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32)) + + def testAlignCorners4x4To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=dtype), [3, 3], + expected=np.array( + [[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32)) + + def testAlignCorners4x4To3x3Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[4, 4], + dtype=dtype, + expected=np.array( + [[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3], + [7, 4, 4, 9]], + dtype=np.float32)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index c00e3035a0982b2b2e59eb6f53499918515ae71d..af9394e7d7dc9cf7dd009420ff9c845aec8785bd 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -96,28 +96,27 @@ class MomentumOptimizerTest(XLATestCase): def testNesterovMomentum(self): for dtype in self.float_types: with self.test_session(), self.test_scope(): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - var0_np = np.array([1.0, 2.0], dtype=dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype) + var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype) + var0_np = np.array([0.1, 0.2], dtype=dtype) + var1_np = np.array([0.3, 0.4], dtype=dtype) accum0_np = np.array([0.0, 0.0], dtype=dtype) accum1_np = np.array([0.0, 0.0], dtype=dtype) - cost = 5 * var0 * var0 + 3 * var1 + cost = 0.4 * var0 * var0 + 0.9 * var1 global_step = resource_variable_ops.ResourceVariable( array_ops.zeros([], dtypes.int32), name="global_step") mom_op = momentum_lib.MomentumOptimizer( - learning_rate=2.0, momentum=0.9, use_nesterov=True) + learning_rate=0.1, momentum=0.9, use_nesterov=True) opt_op = mom_op.minimize(cost, global_step, [var0, var1]) variables.global_variables_initializer().run() for _ in range(1, 5): opt_op.run() var0_np, accum0_np = self._update_nesterov_momentum_numpy( - var0_np, accum0_np, var0_np * 10, 2.0, 0.9) - var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np, - accum1_np, - 3, 2.0, 0.9) - self.assertAllClose(var0_np, var0.eval()) - self.assertAllClose(var1_np, var1.eval()) + var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9) + var1_np, accum1_np = self._update_nesterov_momentum_numpy( + var1_np, accum1_np, 0.9, 0.1, 0.9) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 6a8c3bcd55a6e454a19b6249cf4eb48739c8657f..798daaadbc5be50ef9cf7e1205f6d5a0bde59640 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -2460,6 +2460,36 @@ TEST_F(OpTest, Reshape) { }); } +TEST_F(OpTest, ResizeBilinear) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinear") + .RandomInput(DT_FLOAT, in_dims) + .Input(test::AsTensor( + std::vector(out_dims.begin(), out_dims.end()))) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + +TEST_F(OpTest, ResizeBilinearGrad) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinearGrad") + .RandomInput(DT_FLOAT, in_dims) + .RandomInput(DT_FLOAT, + {in_dims[0], out_dims[0], out_dims[1], in_dims[3]}) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + TEST_F(OpTest, Reverse) { Repeatedly([this]() { std::vector dims = RandomDims(1); diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3260e63b23226d736a7ddc0f21a94a8c791e0442 --- /dev/null +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -0,0 +1,229 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for scan ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def numpy_reverse(x, axis): + length = len(x.shape) + if axis < 0: + axis = length + axis + + ix = [ + slice(None, None, -1) if i == axis else slice(None) for i in range(length) + ] + return x[ix] + + +def handle_options(func, x, axis, exclusive, reverse): + """Adds tf options to numpy scan ops.""" + length = len(x.shape) + if axis < 0: + axis = length + axis + + if reverse: + x = numpy_reverse(x, axis) + + if exclusive: + ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)] + ix_init = [ + slice(0, -1) if i == axis else slice(None) for i in range(length) + ] + if func == np.cumsum: + init = np.zeros_like(x[ix_head]) + elif func == np.cumprod: + init = np.ones_like(x[ix_head]) + else: + raise ValueError("Unknown scan function.") + x = np.concatenate([init, func(x[ix_init], axis)], axis=axis) + else: + x = func(x, axis=axis) + + if reverse: + x = numpy_reverse(x, axis) + return x + + +class CumsumTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( + feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumsum(p, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 10).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumsum(input_tensor, [0]).eval() + + +class CumprodTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + prod = math_ops.cumprod(p, axis, exclusive, reverse) + tf_out = prod.eval(feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumprod(x, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 11).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumprod(input_tensor, [0]).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4336ebdbd184a081619f0a6951dd4514735c6eb6 --- /dev/null +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for stateless random-number generation ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.contrib import stateless +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class StatelessRandomOpsTest(XLATestCase): + """Test cases for stateless random-number generator operators.""" + + def _random_types(self): + return [dtypes.float32] + + def testDeterminism(self): + # Stateless values should be equal iff the seeds are equal (roughly) + with self.test_session(), self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for stateless_op in [ + stateless.stateless_random_uniform, stateless.stateless_random_normal + ]: + for shape in (), (3,), (2, 5): + for dtype in self._random_types(): + pure = stateless_op(shape, seed=seed_t, dtype=dtype) + values = [(seed, pure.eval(feed_dict={ + seed_t: seed + })) for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + + def testRandomUniformIsInRange(self): + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless.stateless_random_uniform( + shape=[1000], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue(np.all(y >= 0)) + self.assertTrue(np.all(y < 1)) + + def _chi_squared(self, x, bins): + """Pearson's Chi-squared test.""" + x = np.ravel(x) + n = len(x) + histogram, _ = np.histogram(x, bins=bins, range=(0, 1)) + expected = n / float(bins) + return np.sum(np.square(histogram - expected) / expected) + + def testDistributionOfStatelessRandomUniform(self): + """Use Pearson's Chi-squared test to test for uniformity.""" + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 1000 + x = stateless.stateless_random_uniform( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [565656, 121212]}) + # Tests that the values are distributed amongst 10 bins with equal + # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with + # p=0.05. This test is probabilistic and would be flaky if the random + # seed were not fixed. + self.assertTrue(self._chi_squared(y, 10) < 16.92) + + def _normal_cdf(self, x): + """Cumulative distribution function for a standard normal distribution.""" + return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) + + def _anderson_darling(self, x): + """Anderson-Darling test for a standard normal distribution.""" + x = np.sort(np.ravel(x)) + n = len(x) + i = np.linspace(1, n, n) + z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + + (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) + return -n - z / n + + def testDistributionOfStatelessRandomNormal(self): + """Use Anderson-Darling test to test distribution appears normal.""" + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 1000 + x = stateless.stateless_random_normal( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [25252, 314159]}) + # The constant 2.492 is the 5% critical value for the Anderson-Darling + # test where the mean and variance are known. This test is probabilistic + # so to avoid flakiness the seed is fixed. + self.assertTrue(self._anderson_darling(y) < 2.492) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index ac039e01623b954e291760fb9b50ef8eae3da7c1..a62925a1818da00cb0a9e82e1281db20fb38b208 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -330,8 +330,7 @@ class TensorArrayTest(xla_test.XLATestCase): # Find two different floating point types, create an array of # the first type, but try to read the other type. if len(self.float_types) > 1: - dtype1 = self.float_types[0] - dtype2 = self.float_types[1] + dtype1, dtype2 = list(self.float_types)[:2] with self.test_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype1, tensor_array_name="foo", size=3) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index a9a3f4f97f649260e9863fff8ff05d046bd91947..0a6fe04d3cdd29f1d40d33be1f4319090e7ba3d1 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -33,6 +33,17 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest +def nhwc_to_format(x, data_format): + """Converts a numpy array from NHWC format to `data_format`.""" + rank = len(x.shape) + if data_format == "NCHW": + return np.transpose(x, [0, rank - 1] + list(range(1, rank - 1))) + elif data_format == "NHWC": + return x + else: + raise ValueError("Unknown format {}".format(data_format)) + + class UnaryOpsTest(XLATestCase): """Test cases for unary operators.""" @@ -56,7 +67,7 @@ class UnaryOpsTest(XLATestCase): output = op(pinp) result = session.run(output, {pinp: inp}) if equality_test is None: - equality_test = self.assertAllClose + equality_test = self.assertAllCloseAccordingToType equality_test(result, expected, rtol=rtol, atol=atol) def ListsAreClose(self, result, expected, rtol, atol): @@ -76,6 +87,12 @@ class UnaryOpsTest(XLATestCase): array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.diag, np.array([[1, 2], [3, 4]], dtype=dtype), + np.array( + [[[[1, 0], [0, 0]], [[0, 2], [0, 0]]], [[[0, 0], [3, 0]], + [[0, 0], [0, 4]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.identity, @@ -86,6 +103,21 @@ class UnaryOpsTest(XLATestCase): array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), + np.array( + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, + np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), + np.array( + [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], + [[4, 0, 0], [0, 5, 0], [0, 0, 6]]], + [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], + [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), @@ -331,26 +363,23 @@ class UnaryOpsTest(XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - # TODO(b/65408531): Wider support for log (needs atan2). - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._assertOpOutputMatchesExpected( - math_ops.acosh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arccosh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arccosh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) - self._assertOpOutputMatchesExpected( - math_ops.asinh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arcsinh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arcsinh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) - self._assertOpOutputMatchesExpected( - math_ops.atanh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arctanh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.atanh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arctanh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.cosh, @@ -377,11 +406,10 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2j, 2 + 3j]], dtype=dtype), expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype)) - if atan2_supported: - self._assertOpOutputMatchesExpected( - math_ops.log, - np.array([[5j, 3 - 2j]], dtype=dtype), - expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.log, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.sin, @@ -395,27 +423,26 @@ class UnaryOpsTest(XLATestCase): # TODO(b/34703906): improve log1p implementation and make tolerance # tighter. - if atan2_supported: # TODO(b/34703906): log support - self._assertOpOutputMatchesExpected( - math_ops.log1p, - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), - expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.log1p, + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), + expected=np.log1p( + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) - val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) - self._assertOpOutputMatchesExpected( - math_ops.rsqrt, val, expected=1 / np.sqrt(val)) + val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.rsqrt, val, expected=1 / np.sqrt(val)) - self._assertOpOutputMatchesExpected( - math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) - self._assertOpOutputMatchesExpected( - math_ops.sqrt, val, expected=np.sqrt(val)) + self._assertOpOutputMatchesExpected( + math_ops.sqrt, val, expected=np.sqrt(val)) - self._assertOpOutputMatchesExpected( - math_ops.tanh, - np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), - expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.tan, @@ -448,12 +475,10 @@ class UnaryOpsTest(XLATestCase): np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype), expected=np.array([[1, 1], [1, 1]], dtype=dtype)) - if atan2_supported: # TODO(b/34703906): atan2 support - self._assertOpOutputMatchesExpected( - math_ops.angle, - np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), - expected=np.angle( - np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.angle, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.angle(np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.conj, @@ -541,7 +566,8 @@ class UnaryOpsTest(XLATestCase): def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] - types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types + types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) | + self.complex_tf_types) for shape in shapes: for src_type in types: for dst_type in types: @@ -641,55 +667,88 @@ class UnaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) def testDepthToSpace(self): + def make_op(data_format): + def op(x): + return array_ops.depth_to_space(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - expected=np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - expected=np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format)) def testSpaceToDepth(self): + def make_op(data_format): + def op(x): + return array_ops.space_to_depth(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format)) def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index c50342dee45eba6ae54f01653ecc81ef096b547b..b08d6ab21e0746558cb3d4818d4c822c45d2e9ee 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -107,11 +107,26 @@ class VariableOpsTest(XLATestCase): [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], ).astype(dtype), sess.run(x)) + def testShape(self): + for dtype in self.numeric_types: + init = np.ones([2, 3]).astype(dtype) + with self.test_session() as session, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + session.run(variables.variables_initializer([v])) + h = v.handle + s32, s64 = session.run([ + resource_variable_ops.variable_shape(h), + resource_variable_ops.variable_shape(h, out_type=dtypes.int64) + ]) + self.assertEqual(s32.dtype, np.int32) + self.assertEqual(s64.dtype, np.int64) + self.assertAllEqual(s32, [2, 3]) + self.assertAllEqual(s64, [2, 3]) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" for dtype in self.numeric_types: with self.test_session() as session: - print(ops.get_default_graph()) with self.test_scope(): with variable_scope.variable_scope("ascope", use_resource=True): x = variable_scope.get_variable( diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 0be127997e5211f810ca791187486760881fe172..7e1f5c76ed65946363cc3c113ab1a9862f87b289 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -53,41 +53,100 @@ class XLATestCase(test.TestCase): super(XLATestCase, self).__init__(method_name) self.device = FLAGS.test_device self.has_custom_call = (self.device == 'XLA_CPU') - self.all_tf_types = [ + self._all_tf_types = set([ dtypes.as_dtype(types_pb2.DataType.Value(name)) for name in FLAGS.types.split(',') - ] - self.int_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_integer - ] - self.float_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_floating - ] - self.complex_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_complex - ] - self.numeric_tf_types = ( - self.int_tf_types + self.float_tf_types + self.complex_tf_types) - - self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] - self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types] - self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types] - self.complex_types = [ + ]) + self.int_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_integer + ]) + self._float_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_floating + ]) + self.complex_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_complex + ]) + self._numeric_tf_types = set( + self.int_tf_types | self._float_tf_types | self.complex_tf_types) + + self._all_types = set( + [dtype.as_numpy_dtype for dtype in self._all_tf_types]) + self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self._float_types = set( + [dtype.as_numpy_dtype for dtype in self._float_tf_types]) + self.complex_types = set([ dtype.as_numpy_dtype for dtype in self.complex_tf_types - ] - self.numeric_types = self.int_types + self.float_types + self.complex_types + ]) + self._numeric_types = set( + self.int_types | self._float_types | self.complex_types) # Parse the manifest file, if any, into a regex identifying tests to # disable self.disabled_regex = None + self._method_types_filter = dict() + # TODO(xpan): Make it text proto if it doesn't scale. + # Each line of the manifest file specifies an entry. The entry can be + # 1) TestNameRegex // E.g. CumprodTest.* Or + # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 + # The 1) disables the entire test. While 2) only filter some numeric types + # so that they are not used in those tests. + if FLAGS.disabled_manifest is not None: comments_re = re.compile('#.*$') manifest_file = open(FLAGS.disabled_manifest, 'r') - lines = manifest_file.read().splitlines() - lines = [comments_re.sub('', l).strip() for l in lines] - self.disabled_regex = re.compile('|'.join(lines)) + disabled_tests = [] + disabled_method_types = [] + for l in manifest_file.read().splitlines(): + entry = comments_re.sub('', l).strip().split(' ') + if len(entry) == 1: + disabled_tests.append(entry[0]) + elif len(entry) == 2: + disabled_method_types.append( + (entry[0], entry[1].strip().split(','))) + else: + raise ValueError('Bad entry in manifest file.') + + self.disabled_regex = re.compile('|'.join(disabled_tests)) + for method, types in disabled_method_types: + self._method_types_filter[method] = set([ + dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype + for name in types]) manifest_file.close() + @property + def all_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + tf_types = set([dtypes.as_dtype(t) + for t in self._method_types_filter.get(name, set())]) + return self._all_tf_types - tf_types + + @property + def float_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._float_types - self._method_types_filter.get(name, set()) + + @property + def float_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._float_tf_types - self._method_types_filter.get(name, set()) + + @property + def numeric_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + tf_types = set([dtypes.as_dtype(t) + for t in self._method_types_filter.get(name, set())]) + return self._numeric_tf_types - tf_types + + @property + def numeric_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._numeric_types - self._method_types_filter.get(name, set()) + + @property + def all_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._all_types - self._method_types_filter.get(name, set()) + def setUp(self): super(XLATestCase, self).setUp() name = '{}.{}'.format(type(self).__name__, self._testMethodName) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5a81438b1c48e7f0ef66dae072092974db24c621..5d1cb6d73570a1a3efbe0d2d37d9746bc0e2528f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") package_group( name = "internal", @@ -25,6 +25,30 @@ package( load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +cc_library( + name = "tf2xla_supported_ops_lib", + srcs = ["tf2xla_supported_ops.cc"], + hdrs = ["tf2xla_supported_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_binary( + name = "tf2xla_supported_ops", + srcs = ["tf2xla_supported_ops_main.cc"], + visibility = ["//visibility:public"], + deps = [":tf2xla_supported_ops_lib"], +) + xla_proto_library( name = "tf2xla_proto", srcs = ["tf2xla.proto"], @@ -67,7 +91,6 @@ cc_library( # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], @@ -215,6 +238,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -357,13 +381,6 @@ tf_cc_test( ], ) -cc_library( - name = "xla_local_runtime_context", - hdrs = ["xla_local_runtime_context.h"], - visibility = ["//visibility:public"], - deps = ["//tensorflow/core:framework_lite"], -) - cc_library( name = "dump_graph", srcs = [ @@ -400,6 +417,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index d57273d84442c17565a6ace1c29170a0f3ba583b..ab2f1e9a7ab577bbe704e568b21d9912439605ca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -52,6 +52,8 @@ Status BackwardsConstAnalysis(const Graph& g, {"Conv2DBackpropInput", "input_sizes"}, {"Conv3DBackpropFilterV2", "filter_sizes"}, {"Conv3DBackpropInputV2", "input_sizes"}, + {"Cumprod", "axis"}, + {"Cumsum", "axis"}, {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"}, {"DepthwiseConv2dNativeBackpropInput", "input_sizes"}, {"DynamicStitch", "indices"}, @@ -78,6 +80,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"Range", "limit"}, {"Range", "delta"}, {"Reshape", "shape"}, + {"ResizeBilinear", "size"}, {"ResourceStridedSliceAssign", "begin"}, {"ResourceStridedSliceAssign", "end"}, {"ResourceStridedSliceAssign", "strides"}, diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index ddd912b87315f7943915153b5bf73531107af54d..03603ee9baefd1d20d220faf63c9c1c427ebdf31 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -63,7 +63,12 @@ string MakeUniquePath(string name) { string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def)); + Status status = WriteTextProto(Env::Default(), path, graph_def); + if (!status.ok()) { + VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; + path.clear(); + path = "(unavailable)"; + } return path; } @@ -79,7 +84,13 @@ string DumpGraphToFile(const string& name, Graph const& graph, string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef)); + Status status = WriteTextProto(Env::Default(), path, fdef); + if (!status.ok()) { + VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " + << status; + path.clear(); + path = "(unavailable)"; + } return path; } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 40a484da0980004b43564f1c57be0426d21379fb..dd67a1dea9656bac9cf3eaa09295a0d42e283706 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -528,255 +529,101 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, class FunctionalizeCond { public: - // Identifies the connected parts of the tf.Cond. - struct ClusterHandle { - explicit ClusterHandle(int representative = -1) - : representative(representative) {} - - bool operator==(const ClusterHandle& other) const { - return representative == other.representative; - } - - bool operator!=(const ClusterHandle& other) const { - return !(*this == other); - } + // All nodes are assumed to be either in no branch, then branch, else branch, + // or both branches (such as merge nodes). + enum Branch { + kElseBranch = 0, + kThenBranch = 1, + kBoth = 2, + kNeither = 3, + kNumBranchTypes = 4 + }; - bool operator<(const ClusterHandle& other) const { - return representative < other.representative; + // Returns a textual representation of the Branch b. + static string Branch_Name(FunctionalizeCond::Branch b); + + // Comparison function used for sorting nodes consistently. + struct CondCmp { + bool operator()(const Node* lhs, const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); } + }; - bool operator>(const ClusterHandle& other) const { - return representative > other.representative; - } + // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf + // nodes. That is, attempt to transform every remaining switch and merge nodes + // in the graph into XlaIf nodes. + // Precondition: All while loops have been removed from graph. + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + private: + struct ForwardFlowNode { + explicit ForwardFlowNode(Branch branch = Branch::kNeither) + : branch(branch), count(0) {} string ToString() const { - return strings::StrCat("Cluster_", representative); + return strings::StrCat("branch=", Branch_Name(branch), " count=", count); } - - // Vector of UnionFind indexable by ClusterHandle and Node*. - struct Vector { - explicit Vector(size_t size) : clusters(size) {} - - UnionFind& at(const ClusterHandle& cluster) { - return clusters.at(cluster.representative); - } - - UnionFind& at(const Node* node) { - return clusters.at(node->id()); - } - - UnionFind& operator[](const Node* node) { - return clusters.at(node->id()); - } - - size_t size() const { return clusters.size(); } - - void resize(size_t count) { return clusters.resize(count); } - - private: - std::vector> clusters; - }; - - private: - int representative; - }; - - // Represents a node in the clustered graph consisting of switch_nodes, - // merge_nodes as well as the edges into and out of this node to other - // Clusters. Each Cluster corresponds to a ClusterHandle and has a - // corresponding representative. - struct Cluster { - std::unordered_set switch_nodes; - std::unordered_set merge_nodes; - std::unordered_set in_nodes; - std::unordered_set out_nodes; - - // A member of the ClusterHandle corresponding to this Cluster. - ClusterHandle representative; - bool visited = false; - }; - - // Represent the clustered graph as map from cluster representative to - // Cluster. - using ClusteredGraph = std::map; - - // The arguments and condition of a XlaIf. The arguments are ordered by node - // id in the original graph. - struct CondArgs { - struct CondCmp { - bool operator()(const Node* lhs, const Node* rhs) const { - bool lhs_is_resource = - lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; - bool rhs_is_resource = - rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; - return std::tie(lhs_is_resource, lhs->name()) < - std::tie(rhs_is_resource, rhs->name()); - } - }; - Node* conditional = nullptr; - std::set args; + Branch branch; + int count; }; - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - - private: FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : clusters_(graph->num_node_ids()), library_(library), graph_(graph) {} - - // Returns a vector of Merge nodes from the clustered graph where the nodes - // are sorted by the number of switch nodes minus number of merge nodes - // from a root of the clustered graph to the given Merge node, with ties - // broken by the representative of the Cluster. - std::vector> SortedMergeNodes(); - - // Returns whether the graph has no conditionals. - bool NoConditionals() const { return merge_nodes_.empty(); } - - // Construct the clustered graph by creating nodes for each cluster and the - // connections between the clusters. Switch and Merge nodes partition - // clusters, so iterate over those. Note: a Cluster may have neither a - // Merge or Switch but will have an in/out edge from a Cluster that has. - void CreateClusters(); - - // Creates the clustered graph by identifying all the edges between different - // clusters and collecting all switch and merge nodes that correspond to a - // cluster. - void CreateClusteredGraph(); - - // If `from` and `to` correspond to different clusters, then merge the nodes - // in the clustered graph corresponding to `from` and `to`. - // - // If `remove_from_graph` is specified then the `from` node is also removed - // from the clustered graph post contracting the edge. - void ContractEdge(Cluster* from, Cluster* to, bool remove_from_graph = false); + : library_(library), graph_(graph) {} + + // Perform the actual cond functionalization. Iterate over groups of switch + // nodes (linked by common predicate), from innermost to outermost, and + // extract into XlaIf nodes. + Status FunctionalizeInternal(); // Converts a Merge node to a XlaIf. This encapsulates the process of // extracting the bodies needed for the then and else branch, creates a XlaIf // node, removing the nodes of the branches from the graph and replacing the // merge node with a XlaIf. - Status ConvertMergeToXlaIf(Cluster* merge_cluster); - - // Removes a Switch cluster feeding directly into a Merge cluster by removing - // the Switch and Merge nodes and collapsing into a single cluster. - Status RemoveTrivialMerge(Cluster* merge_cluster); - - // Returns the switch cluster corresponding to the merge node. This function - // only returns the switch cluster in the simple case where we have a switch - // node is the entry of a diamond corresponding to a conditional: - // - // Switch - // / \ - // Branch Branch - // \ / - // merge_cluster - // - // Note: either of the branches may be empty. The case where both branches are - // empty is handled by RemoveTrivialMerge. - gtl::optional GetSwitchCluster(const Cluster& merge_cluster); - - // Determines the arguments needed as input to the Merge cluster originating - // from the Switch cluster. - xla::StatusOr DetermineCondArgs(const Cluster& merge_cluster, - const Cluster& switch_cluster); - - // Builds a XlaIfOp to replace the Merge node with. - xla::StatusOr BuildAndAddXlaIfOp(const CondArgs& cond_args, - const Cluster& merge_cluster, - const std::vector& outputs); + Status ConvertCorrespondingMergeToXlaIf( + const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate); + + // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. + xla::StatusOr BuildAndAddXlaIfOp( + const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate); // Extracts a function body corresponding to the given input edge of the merge // node. - Status ExtractBody(const CondArgs& cond_args, const Cluster& merge_cluster, - const std::vector& outputs, int input_edge, + Status ExtractBody(const std::vector& switch_nodes, + const std::vector& merge_nodes, int input_edge, Graph* body); // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgs& cond_args, Node* if_node); + Status AddInputEdges(const std::vector& cond_args, Node* predicate, + Node* if_node); // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); - // Removes all nodes from the graph that are part of cluster. - void RemoveClusterNodes(Cluster* cluster); + // Returns the switches of graph_ in postorder. Dead switch nodes are skipped + // and removed from the graph. + std::vector DetermineSwitchOrder(); - // Removes all argument nodes that are unused. - template - void RemoveUnusedArgs(const T& args); + // Update the state for destination based on the state of source and the node + // being updated. + Status Join(const ForwardFlowNode& src_state, const Node* dst, + ForwardFlowNode* dst_state); - // Removes all Merge nodes in merge_cluster. - void RemoveMergeNodes(Cluster* merge_cluster); - - // Returns the representative member of the corresponding cluster. - ClusterHandle Representative(const Node* node) { - return clusters_.at(node).Get(); - } + // Validates that the branch_map and frontier of nodes for the conditional + // section are as expected. + Status ValidBranchMapAndFrontier( + const std::unordered_map& branch_map, + const std::unordered_set& frontier); - ClusteredGraph clustered_graph_; - ClusterHandle::Vector clusters_; - std::unordered_set merge_nodes_; - std::unordered_set switch_nodes_; FunctionLibraryDefinition* library_; Graph* graph_; }; -std::ostream& operator<<(std::ostream& os, - const FunctionalizeCond::ClusterHandle& c) { - os << c.ToString(); - return os; -} - -// Returns a dot representation of the clustered graph showing the connections -// between the nodes and the nodes in each cluster. -string DebugString(const Graph& graph, - FunctionalizeCond::ClusterHandle::Vector* clusters) { - string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; - std::map subgraphs; - auto name = [](const Node* n) { - return strings::StrCat(n->type_string(), "_", n->id()); - }; - for (Node* n : graph.nodes()) { - strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"", - name(n), "\"];\n"); - } - for (auto kv : subgraphs) { - strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", - "style=filled; color=lightgrey;", "label = \"", - kv.first.ToString(), "\";\n", kv.second, "}\n"); - } - for (Node* n : graph.nodes()) { - for (Node* in : n->in_nodes()) { - strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); - } - } - return strings::StrCat(ret, "} // end"); -} - -string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { - string ret = "digraph {\ncompound=true;labeljust=\"r\";\n"; - auto name = [](const FunctionalizeCond::Cluster& cluster) { - return cluster.representative.ToString(); - }; - for (auto kv : clustered_graph) { - if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) { - strings::StrAppend( - &ret, kv.first.ToString(), " [label=\"", name(kv.second), - kv.second.switch_nodes.empty() - ? "" - : strings::StrCat(" switches=", kv.second.switch_nodes.size()), - kv.second.merge_nodes.empty() - ? "" - : strings::StrCat(" merges=", kv.second.merge_nodes.size()), - "\"];\n"); - } - } - for (auto kv : clustered_graph) { - for (auto in : kv.second.in_nodes) { - strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); - } - } - return strings::StrCat(ret, "} // end"); -} - bool IsDeadSwitch(const Node* node) { for (const Edge* e : node->out_edges()) { const Node* dst = e->dst(); @@ -792,243 +639,212 @@ bool IsDeadSwitch(const Node* node) { return true; } -void FunctionalizeCond::CreateClusters() { - for (Node* node : graph_->nodes()) { - if (IsSwitch(node)) { - switch_nodes_.insert(node); - } else if (IsMerge(node)) { - merge_nodes_.insert(node); - } - ClusterHandle& cluster = clusters_.at(node).Get(); - cluster = ClusterHandle(node->id()); - } - - // If there are no Merge nodes, then terminate. - if (merge_nodes_.empty()) { - return; - } - - // Remove all dead Switch nodes. - RemoveUnusedArgs(switch_nodes_); - - // All parent_'s are still nullptr so clusters_ may still be resized. Resize - // conservatively assuming all merge nodes become XlaIf nodes. - clusters_.resize(clusters_.size() + merge_nodes_.size()); +string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) { + const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = { + "else", "then", "both", "neither", "count"}; + return branch_name[b]; +} - // Merge a cluster with its input, unless the input is a Switch node or - // the node is a Merge node. - for (const Node* node : graph_->nodes()) { - if (IsMerge(node) || IsSwitch(node) || !node->IsOp()) { - continue; +Status FunctionalizeCond::ValidBranchMapAndFrontier( + const std::unordered_map& + branch_map, + const std::unordered_set& frontier) { + std::unordered_set pending[kNumBranchTypes]; + for (const auto& kv : branch_map) { + if (kv.second.count != kv.first->in_edges().size()) { + return errors::FailedPrecondition("Value ", kv.first->DebugString(), + " not dominated by switch nodes."); } - for (const Node* in : node->in_nodes()) { - if (in->IsOp() && !IsSwitch(in) && !IsMerge(in)) { - clusters_.at(node).Merge(&clusters_.at(in)); - } + if (VLOG_IS_ON(1)) { + // Append attribute to the graph if running with logging to make the + // changes clearer in the visualization. + kv.first->AddAttr("_XlaFunctionalizeBranch", + Branch_Name(kv.second.branch)); } - // Group all source clusters together. - if (node->IsSource() || node->in_edges().empty()) { - clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId))); + } + for (Node* n : frontier) { + pending[branch_map.at(n).branch].insert(n); + } + TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]); + for (const Node* n : pending[kBoth]) { + TF_RET_CHECK(IsMerge(n)) << n->DebugString(); + // Merge nodes may be in then or else branch too + } + int index = (pending[kThenBranch].size() <= pending[kElseBranch].size()) + ? kThenBranch + : kElseBranch; + int other = 1 - index; + for (const Node* n : pending[index]) { + if (pending[other].find(n) != pending[other].end()) { + return errors::Internal( + "Node (", n->DebugString().c_str(), + ") in both Else and Then branch should be in Both."); } } + return Status::OK(); } -void FunctionalizeCond::ContractEdge(Cluster* from, Cluster* to, - bool remove_from_graph) { - VLOG(3) << "ContractEdge from = " << from->representative - << " to = " << to->representative; - if (from->representative == to->representative) { - return; - } - to->merge_nodes.insert(from->merge_nodes.begin(), from->merge_nodes.end()); - from->merge_nodes.clear(); - to->switch_nodes.insert(from->switch_nodes.begin(), from->switch_nodes.end()); - from->switch_nodes.clear(); - - for (Cluster* from_out : from->out_nodes) { - from_out->in_nodes.erase(from); - if (from_out->representative != to->representative) { - from_out->in_nodes.insert(to); - to->out_nodes.insert(from_out); +Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, + const Node* dst, ForwardFlowNode* dst_state) { + TF_RET_CHECK(dst_state->branch != Branch::kBoth && + dst_state->branch != Branch::kNumBranchTypes) + << "Unexpected/Invalid branch type: Merging " + << Branch_Name(src_state.branch) << " with " + << Branch_Name(dst_state->branch); + if (dst_state->branch == Branch::kNeither) { + dst_state->branch = src_state.branch; + } else if (src_state.branch != dst_state->branch && + src_state.branch != Branch::kNeither) { + if (IsMerge(dst)) { + dst_state->branch = Branch::kBoth; + } else { + return errors::Internal("Illegal merge: ", src_state.ToString(), " with ", + dst_state->ToString(), " for ", + dst->DebugString()); } } - from->out_nodes.clear(); + ++dst_state->count; + return Status::OK(); +} - for (Cluster* from_in : from->in_nodes) { - from_in->out_nodes.erase(from); - if (from_in->representative != to->representative) { - from_in->out_nodes.insert(to); - to->in_nodes.insert(from_in); +std::vector FunctionalizeCond::DetermineSwitchOrder() { + std::vector dead_switches; + std::vector switch_order; + DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) { + if (IsSwitch(n)) { + if (IsDeadSwitch(n)) { + dead_switches.push_back(n); + } else { + switch_order.push_back(n); + } } + }); + + // Remove all dead switch nodes. + for (Node* n : dead_switches) { + graph_->RemoveNode(n); } - from->in_nodes.clear(); - to->in_nodes.erase(from); - to->out_nodes.erase(from); - clusters_.at(to->representative).Merge(&clusters_.at(from->representative)); - from->visited = true; + return switch_order; +} - if (remove_from_graph) { - clustered_graph_.erase(from->representative); +Status FunctionalizeCond::FunctionalizeInternal() { + std::vector switch_order = DetermineSwitchOrder(); + // If there are no switch nodes, then terminate. + if (switch_order.empty()) { + return Status::OK(); } -} -void FunctionalizeCond::CreateClusteredGraph() { - auto update_cluster_for_node = [this](Node* node) -> Cluster& { - ClusterHandle repr = Representative(node); - Cluster& cluster_node = clustered_graph_[repr]; - cluster_node.representative = repr; - for (const Node* in : node->in_nodes()) { - ClusterHandle other_repr = Representative(in); - // Skip source, sink and internal edges. - if (other_repr == repr) { - continue; - } - Cluster& cluster_node_in = clustered_graph_[other_repr]; - cluster_node.in_nodes.insert(&cluster_node_in); - cluster_node_in.out_nodes.insert(&cluster_node); - cluster_node_in.representative = other_repr; - } - for (const Node* out : node->out_nodes()) { - ClusterHandle other_repr = Representative(out); - // Skip source, sink and internal edges. - if (other_repr == repr) { - continue; - } - Cluster& cluster_node_out = clustered_graph_[other_repr]; - cluster_node.out_nodes.insert(&cluster_node_out); - cluster_node_out.in_nodes.insert(&cluster_node); - cluster_node_out.representative = other_repr; - } - return cluster_node; + struct PredicateSwitches { + explicit PredicateSwitches(Node* predicate) : predicate(predicate) {} + + Node* predicate; + std::vector switches; }; - update_cluster_for_node(graph_->source_node()); - for (Node* node : switch_nodes_) { - update_cluster_for_node(node).switch_nodes.insert(node); - } - for (Node* node : merge_nodes_) { - update_cluster_for_node(node).merge_nodes.insert(node); - } // Merge Switch nodes with common predicate. - std::unordered_map> predicate_to_switch; - for (Node* node : switch_nodes_) { - Node* tmp; - TF_CHECK_OK(node->input_node(1, &tmp)); - predicate_to_switch[tmp].push_back(node); - } - for (auto kv : predicate_to_switch) { - Cluster& first = clustered_graph_.at(Representative(kv.second.front())); - for (Node* switch_node : kv.second) { - ClusterHandle handle = Representative(switch_node); - Cluster& cluster = clustered_graph_.at(handle); - ContractEdge(&cluster, &first, /*remove_from_graph=*/true); + std::vector predicate_switch_order; + std::unordered_map predicate_index; + // The nodes in switch_order are in reverse topological order, but the + // clustered switches need not be (i.e., when considered as a cluster one + // element of a cluster may be later in the topological order than another + // node whose cluster is later in the topological order of clustered + // switches). + for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { + Node* pred; + TF_CHECK_OK((*it)->input_node(1, &pred)); + if (predicate_index.find(pred) == predicate_index.end()) { + predicate_index[pred] = predicate_switch_order.size(); + predicate_switch_order.emplace_back(pred); } + predicate_switch_order[predicate_index[pred]].switches.push_back(*it); } - // Merge Merge nodes with common input together. - for (Node* node : merge_nodes_) { - Cluster& cluster = clustered_graph_.at(Representative(node)); - for (const Node* in : node->in_nodes()) { - if (!in->IsOp()) { - continue; - } - Cluster& cluster_node_in = clustered_graph_.at(Representative(in)); - // ContractEdge can modify out_nodes of cluster_node_in, so traverse - // over out_nodes assuming it does. - for (auto it = cluster_node_in.out_nodes.begin(); - it != cluster_node_in.out_nodes.end();) { - if (!(*it)->merge_nodes.empty()) { - ContractEdge(*it++, &cluster, /*remove_from_graph=*/true); - } else { - ++it; - } - } - } - } + // Iterate from innermost set of clustered switches to outermost, replacing + // matching switch->merge subgraphs with single XlaIf nodes. + for (auto it = predicate_switch_order.rbegin(); + it != predicate_switch_order.rend(); ++it) { + auto& ps = *it; + VLOG(3) << "Flow down from: " << ps.predicate->name() << " -> " + << NodesToString(ps.switches); - VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_); - VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_); -} + std::unordered_map branch_map; + std::unordered_set frontier; -gtl::optional FunctionalizeCond::GetSwitchCluster( - const Cluster& merge_cluster) { - VLOG(3) << "GetSwitchCluster for " << merge_cluster.representative; - gtl::optional switch_cluster; - if (merge_cluster.in_nodes.size() > 2) { - return gtl::nullopt; - } - for (Cluster* in : merge_cluster.in_nodes) { - Cluster* cluster = in; - if (in->switch_nodes.empty()) { - if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) { - return gtl::nullopt; - } - // There is only a single `in` cluster. - cluster = *in->in_nodes.begin(); - } - if (cluster->switch_nodes.empty()) { - return gtl::nullopt; - } + std::vector stack = ps.switches; + std::vector visited(graph_->num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); - if (switch_cluster.has_value() && *switch_cluster != cluster) { - return gtl::nullopt; - } else { - switch_cluster = cluster; - } - } - return switch_cluster; -} + if (visited[n->id()]) { + continue; + } + visited[n->id()] = true; -xla::StatusOr FunctionalizeCond::DetermineCondArgs( - const Cluster& merge_cluster, const Cluster& switch_cluster) { - VLOG(2) << "DetermineCondArgs for " << merge_cluster.representative - << " with switch cluster " << switch_cluster.representative; - CondArgs ret; - auto feeds_into_branch_cluster = [&](Node* switch_cluster) { - for (Node* out : switch_cluster->out_nodes()) { - ClusterHandle repr = Representative(out); - if (repr == merge_cluster.representative) { - return true; + // Propagate branch state along each edge of a switch node. + bool sink_only = true; + for (const Edge* e : n->out_edges()) { + Node* out = e->dst(); + if (!out->IsOp()) { + continue; + } + sink_only = false; + // Propagate branch information. + ForwardFlowNode& ffn = branch_map[out]; + if (IsSwitch(n)) { + int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); + TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn)); + } else { + TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn)); + } + if (IsMerge(out)) { + if (out->in_edges().size() == ffn.count) { + frontier.insert(out); + } + } else if (!visited[out->id()] && ffn.count == out->in_edges().size()) { + // If all predecessors are dominated by the switch nodes, then add + // the output to the stack. + stack.push_back(out); + } } - for (Cluster* in : merge_cluster.in_nodes) { - if (repr == in->representative) { - return true; + if (sink_only) { + if (!IsIdentity(n)) { + VLOG(1) << "Feeding into sink: " << n->DebugString(); } } } - return false; - }; - for (Node* switch_cluster_node : switch_cluster.switch_nodes) { - if (!feeds_into_branch_cluster(switch_cluster_node)) { - continue; - } - Node* tmp; - TF_RETURN_IF_ERROR(switch_cluster_node->input_node(1, &tmp)); - if (ret.conditional == nullptr) { - ret.conditional = tmp; - } else if (ret.conditional != tmp) { - return errors::Unimplemented( - "Switch statements with different conditionals cannot be " - "converted into functional conditional."); + TF_RETURN_IF_ERROR(ValidBranchMapAndFrontier(branch_map, frontier)); + VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_bc", *graph_); + std::vector switch_nodes(ps.switches); + std::sort(switch_nodes.begin(), switch_nodes.end(), CondCmp()); + std::vector merge_nodes(frontier.begin(), frontier.end()); + std::sort(merge_nodes.begin(), merge_nodes.end(), CondCmp()); + TF_RETURN_IF_ERROR(ConvertCorrespondingMergeToXlaIf( + switch_nodes, merge_nodes, ps.predicate)); + for (auto& del_kv : branch_map) { + graph_->RemoveNode(del_kv.first); + } + for (Node* node : switch_nodes) { + graph_->RemoveNode(node); } - ret.args.insert(switch_cluster_node); + VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_ac", *graph_); } - return ret; + return Status::OK(); } xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgs& cond_args, const Cluster& merge_cluster, - const std::vector& outputs) { - VLOG(2) << "Build if op for " << NodesToString(merge_cluster.merge_nodes) - << " with input " << NodesToString(cond_args.args); + const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate) { + VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input " + << NodesToString(switch_nodes); NodeDef if_def; // Create a new If node using the name of the merge node. - NodeDefBuilder builder( - strings::StrCat((*merge_cluster.merge_nodes.begin())->name(), "_If"), - "XlaIf"); + NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf"); string branch[] = {"else_branch", "then_branch"}; for (int i = 0; i < 2; ++i) { static std::atomic sequence_num(0LL); @@ -1038,8 +854,7 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( body_name.set_name( strings::StrCat("_functionalize_if_", branch[i], "_", id)); auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR( - ExtractBody(cond_args, merge_cluster, outputs, i, body.get())); + TF_RETURN_IF_ERROR(ExtractBody(switch_nodes, merge_nodes, i, body.get())); VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); FunctionDef body_fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); @@ -1050,7 +865,7 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( // Build input type. std::vector inputs; DataTypeVector in_arg_types; - for (const Node* arg : cond_args.args) { + for (const Node* arg : switch_nodes) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1066,17 +881,17 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( // Build output type. DataTypeVector out_type; - for (const Node* merge : merge_cluster.merge_nodes) { + for (const Node* merge : merge_nodes) { DataType dtype = merge->output_type(0); out_type.push_back(dtype); } builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(cond_args.conditional->assigned_device_name()); + builder.Device(predicate->assigned_device_name()); // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut(cond_args.conditional->name(), 0, - cond_args.conditional->output_type(0))); + builder.Input( + NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1085,53 +900,15 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( return if_node; } -void FunctionalizeCond::RemoveClusterNodes(Cluster* cluster) { - VLOG(3) << "RemoveClusterNodes for " << cluster->representative; - ClusterHandle repr = cluster->representative; - std::deque to_delete; - for (Node* node : graph_->nodes()) { - if (Representative(node) == repr) { - to_delete.push_back(node); - } - } - for (Node* n : to_delete) { - graph_->RemoveNode(n); - } -} - -template -void FunctionalizeCond::RemoveUnusedArgs(const T& args) { - VLOG(2) << "RemoveUnusedArgs among: " << NodesToString(args); - - std::deque to_delete; - for (Node* arg : args) { - if (IsDeadSwitch(arg)) { - to_delete.push_back(arg); - for (Node* n : arg->out_nodes()) { - to_delete.push_back(n); - } - } - } - for (Node* n : to_delete) { - switch_nodes_.erase(n); - auto it = clustered_graph_.find(Representative(n)); - if (it != clustered_graph_.end()) { - it->second.switch_nodes.erase(n); - } - graph_->RemoveNode(n); - } -} - -Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, - const Cluster& merge_cluster, - const std::vector& outputs, +Status FunctionalizeCond::ExtractBody(const std::vector& switch_nodes, + const std::vector& merge_nodes, int input_edge, Graph* body) { - VLOG(2) << "ExtractBody for " << merge_cluster.representative - << " along edge " << input_edge; + VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " + << input_edge; std::vector squash_src_outputs(graph_->num_node_ids(), false); std::vector node_map(graph_->num_node_ids(), nullptr); int arg_count = 0; - for (const auto* arg : cond_args.args) { + for (const auto* arg : switch_nodes) { DataType dtype = arg->input_type(0); TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(body, dtype, arg_count++)); @@ -1140,9 +917,9 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, } std::vector stack; - stack.reserve(outputs.size()); - for (int j = 0; j < outputs.size(); ++j) { - Node* node = outputs[j]; + stack.reserve(switch_nodes.size()); + for (int j = 0; j < merge_nodes.size(); ++j) { + Node* node = merge_nodes[j]; TF_ASSIGN_OR_RETURN(node_map.at(node->id()), BuildRetvalNode(body, node->output_type(0), /*index=*/j)); @@ -1153,7 +930,8 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, node_map.at(in->id()) = body->CopyNode(in); } - if (cond_args.args.find(in) == cond_args.args.end()) { + if (std::find(switch_nodes.begin(), switch_nodes.end(), in) == + switch_nodes.end()) { body->AddEdge(node_map.at(in->id()), in_edge->src_output(), node_map.at(node->id()), 0); } else { @@ -1168,12 +946,12 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, body); } -Status FunctionalizeCond::AddInputEdges(const CondArgs& cond_args, - Node* if_node) { +Status FunctionalizeCond::AddInputEdges(const std::vector& cond_args, + Node* predicate, Node* if_node) { VLOG(3) << "AddInputEdges for " << if_node->name(); int i = 0; - graph_->AddEdge(cond_args.conditional, 0, if_node, i++); - for (const Node* arg : cond_args.args) { + graph_->AddEdge(predicate, 0, if_node, i++); + for (const Node* arg : cond_args) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1210,173 +988,26 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, return Status::OK(); } -void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) { - VLOG(3) << "RemoveMergeNodes for " << merge_cluster->representative; - // Remove all merge nodes now dead post extraction of If. - for (auto it = merge_cluster->merge_nodes.begin(); - it != merge_cluster->merge_nodes.end();) { - Node* node = *it; - graph_->RemoveNode(node); - merge_cluster->merge_nodes.erase(*it++); - } -} - -Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) { - Cluster* switch_cluster = *merge_cluster->in_nodes.begin(); - if (switch_cluster->switch_nodes.empty()) { - return errors::FailedPrecondition( - "Not a trivial merge: no Switch node feeding into Merge node"); - } - - for (auto it = merge_cluster->merge_nodes.begin(); - it != merge_cluster->merge_nodes.end();) { - // We have the following structure: - // Op -> Switch -> Merge -> Consumer - // and we want to transform it to: - // Op -> Consumer - Node* merge_node = *it; - Node* switch_node; - const Edge* in = nullptr; - TF_RETURN_IF_ERROR(merge_node->input_node(0, &switch_node)); - TF_RETURN_IF_ERROR(switch_node->input_edge(0, &in)); - for (auto out : merge_node->out_edges()) { - int src_output = out->dst_input() == Graph::kControlSlot - ? Graph::kControlSlot - : in->src_output(); - graph_->AddEdge(in->src(), src_output, out->dst(), out->dst_input()); - } - graph_->RemoveNode(*it++); - } - RemoveUnusedArgs(switch_cluster->switch_nodes); - - return Status::OK(); -} - -Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) { - VLOG(1) << "ConvertMergeToXlaIf for " << merge_cluster->representative; - gtl::optional switch_cluster = GetSwitchCluster(*merge_cluster); - if (!switch_cluster.has_value()) { - return errors::FailedPrecondition( - "Merge cluster was not part of a simple conditional in the clustered " - "graph. Graph nodes in merge cluster ", - NodesToString(merge_cluster->merge_nodes)); - } - TF_ASSIGN_OR_RETURN(auto cond_args, - DetermineCondArgs(*merge_cluster, **switch_cluster)); - - // Sort the outputs by ID to produce more stable output. - std::vector outputs(merge_cluster->merge_nodes.begin(), - merge_cluster->merge_nodes.end()); - std::sort(outputs.begin(), outputs.end(), CondArgs::CondCmp()); +Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf( + const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate) { + VLOG(1) << "ConvertMergeToXlaIf for " << NodesToString(switch_nodes) << " -> " + << NodesToString(merge_nodes); // Extract bodies and builds a If operator. TF_ASSIGN_OR_RETURN(Node * if_node, - BuildAndAddXlaIfOp(cond_args, *merge_cluster, outputs)); - TF_RETURN_IF_ERROR(AddInputEdges(cond_args, if_node)); - TF_RETURN_IF_ERROR(AddOutputEdges(outputs, if_node)); - - // Remove the old nodes from the graph_ and contract the edges of the - // clustered graph. - for (auto in : merge_cluster->in_nodes) { - if (in != *switch_cluster) { - RemoveClusterNodes(in); - } - } - RemoveMergeNodes(merge_cluster); - RemoveUnusedArgs(cond_args.args); - auto in_nodes = merge_cluster->in_nodes; - for (auto it = in_nodes.begin(); it != in_nodes.end();) { - ContractEdge(*it++, merge_cluster); - } - ContractEdge(*switch_cluster, merge_cluster); - clusters_[if_node].Get() = ClusterHandle(merge_cluster->representative); + BuildAndAddXlaIfOp(switch_nodes, merge_nodes, predicate)); + TF_RETURN_IF_ERROR(AddInputEdges(switch_nodes, predicate, if_node)); + TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); return Status::OK(); } -std::vector> -FunctionalizeCond::SortedMergeNodes() { - VLOG(2) << "ProcessClusteredGraph"; - std::stack> stack; - // Initialize with the source node. - stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]}); - - // Perform a depth-first traversal of the clustered graph computing the - // switch-merge depth. - std::vector> queue; - std::unordered_set visited; - while (!stack.empty()) { - Cluster* n = stack.top().second; - size_t depth = stack.top().first; - stack.pop(); - - auto inserted = visited.insert(n); - if (!inserted.second) { - continue; - } - - size_t new_depth = depth; - if (!n->merge_nodes.empty()) { - queue.emplace_back(depth, n); - --new_depth; - } - if (!n->switch_nodes.empty()) { - ++new_depth; - } - for (Cluster* e : n->out_nodes) { - stack.emplace(new_depth, e); - } - } - - // Sort in reverse order of switch-merge depth with ties broken by the - // ClusterHandle. - std::sort(queue.begin(), queue.end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return std::tie(lhs.first, lhs.second->representative) > - std::tie(rhs.first, rhs.second->representative); - }); - - return queue; -} - Status FunctionalizeCond::Functionalize(Graph* graph, FunctionLibraryDefinition* library) { VLOG(1) << "FunctionalizeCond::Functionalize"; FunctionalizeCond fc(graph, library); - fc.CreateClusters(); - if (fc.NoConditionals()) { - return Status::OK(); - } - fc.CreateClusteredGraph(); - - auto queue = fc.SortedMergeNodes(); - for (auto it = queue.begin(); it != queue.end();) { - Cluster* merge_cluster = (*it).second; - ++it; - if (merge_cluster->in_nodes.size() == 1) { - TF_RETURN_IF_ERROR(fc.RemoveTrivialMerge(merge_cluster)); - } else { - TF_RETURN_IF_ERROR(fc.ConvertMergeToXlaIf(merge_cluster)); - } - - // Contract newly Merge free merge_cluster with incoming nodes without - // Switch or Merge nodes. - std::vector in_nodes(merge_cluster->in_nodes.begin(), - merge_cluster->in_nodes.end()); - for (auto in : in_nodes) { - if (in->merge_nodes.empty() && in->switch_nodes.empty()) { - fc.ContractEdge(in, merge_cluster); - } - } - } - - if (!fc.switch_nodes_.empty()) { - return errors::Internal( - "Failed to functionalize control flow with Switch nodes remaining: ", - NodesToString(fc.switch_nodes_)); - } - return Status::OK(); + return fc.FunctionalizeInternal(); } } // namespace diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 01d2b282751f387cfa9c8887cdeb48090c96bff4..71f12a13339b9b5495631b8f9350579f6a0785a3 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -109,7 +109,7 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName("cond/Merge_If"), less, + auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); GraphDef expected; diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..82b3b46a2f1e97001d1e0c6b993ec243170bc7d8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -0,0 +1,242 @@ +**Supported operators for device: XLA_CPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={complex64,double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`RandomStandardNormal` | `dtype={float}` +`RandomUniform` | `T={int32,int64}`
`dtype={double,float}` +`RandomUniformInt` | `T={int32,int64}`
`Tout={int32,int64}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`TruncatedNormal` | `T={int32,int64}`
`dtype={double,float}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_CPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..d4b7621ad2858fe17e93d292dd807e4f7c1c336b --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -0,0 +1,238 @@ +**Supported operators for device: XLA_GPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={complex64,double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_GPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 948d7f0b407124613dbd58efb2e189b5fca4f6ed..3e24cf042e17ad4e212d82ac4f24fec06a6c780f 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -35,6 +35,7 @@ tf_kernel_library( "gather_op.cc", "gather_op_helpers.h", "identity_op.cc", + "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", "lrn_ops.cc", @@ -54,17 +55,20 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "scan_ops.cc", "segment_reduction_ops.cc", "select_op.cc", "sendrecv_ops.cc", "sequence_ops.cc", "shape_op.cc", + "shape_util.cc", "slice_op.cc", "softmax_op.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", "split_op.cc", "stack_ops.cc", + "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", @@ -77,6 +81,7 @@ tf_kernel_library( hdrs = [ "gather_op.h", "index_ops.h", + "shape_util.h", ], deps = [ ":while_op", @@ -84,7 +89,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", + "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -93,9 +100,11 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:framework", + "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:linalg_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stateless_random_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", "//tensorflow/core/kernels:constant_op", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 248e9d111e556dcdd75581aa6562a66fc8b57063..a249b1869f547f8e5aa725f9f5cf391b10429928 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // XLA implementation of BatchNorm operations. -#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,43 +26,63 @@ namespace { class FusedBatchNormOp : public XlaOpKernel { public: explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format); - } + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); + + xla::ComputationBuilder* builder = ctx->builder(); + + xla::ComputationDataHandle input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. As a workaround, cast everything to the statistics type (which + // may be more precise than the input type). + input = builder->ConvertElementType(input, scale_type); + if (is_training_) { - xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining( - ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, - feature_index_); + xla::ComputationDataHandle output = builder->BatchNormTraining( + input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); - } + ctx->SetOutput(0, builder->ConvertElementType( + builder->GetTupleElement(output, 0), input_type)); + ctx->SetOutput(1, builder->GetTupleElement(output, 1)); + ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + ctx->SetOutput(3, builder->GetTupleElement(output, 1)); + ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = ctx->builder()->BatchNormInference( - ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3), - ctx->Input(4), epsilon_, feature_index_); - ctx->SetOutput(0, output); + xla::ComputationDataHandle output = builder->BatchNormInference( + input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), + epsilon_, feature_index); + ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -73,55 +93,113 @@ class FusedBatchNormOp : public XlaOpKernel { private: float epsilon_; - int64 feature_index_; + TensorFormat data_format_; bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); +REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); class FusedBatchNormGradOp : public XlaOpKernel { public: explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - bool is_training; - OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training)); - CHECK(is_training) << "FusedBatchNormGradOp with is_training=False cannot " - "be used with XLA for now!"; - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(4, tensor_format); - } + OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { - auto grad_output = ctx->Input(0); - auto activation = ctx->Input(1); + xla::ComputationBuilder* b = ctx->builder(); + + auto grad_backprop = ctx->Input(0); + auto activations = ctx->Input(1); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); - xla::ComputationDataHandle output = ctx->builder()->BatchNormGrad( - activation, scale, mean, var, grad_output, epsilon_, feature_index_); - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); + TensorShape input_shape = ctx->InputShape(0); + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + DataType input_dtype = ctx->input_type(0); + DataType scale_dtype = ctx->input_type(2); + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. For now, cast everything to the statistics type (which + // may be more precise than the input type). + grad_backprop = b->ConvertElementType(grad_backprop, scale_type); + activations = b->ConvertElementType(activations, scale_type); + + xla::ComputationDataHandle x_backprop; + xla::ComputationDataHandle scale_backprop; + xla::ComputationDataHandle offset_backprop; + if (is_training_) { + xla::ComputationDataHandle output = + b->BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); + + x_backprop = b->GetTupleElement(output, 0); + scale_backprop = b->GetTupleElement(output, 1); + offset_backprop = b->GetTupleElement(output, 2); + } else { + // Reduce over all dimensions except the feature dim. + std::vector reduction_dims(input_shape.dims() - 1); + std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, + 0); + std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), + feature_index + 1); + // offset_backprop = sum(y_backprop) + // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + + // epsilon)) + // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) + offset_backprop = + b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), + *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + + // scratch1 = rsqrt(pop_var + epsilon) + auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); + auto scratch1 = + b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); + + // scratch2 = sum(y_backprop * (x - mean)) + auto scratch2 = b->Reduce( + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), + XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), + reduction_dims); + + x_backprop = + b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); + scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + + ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(1, scale_backprop); + ctx->SetOutput(2, offset_backprop); + ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); + ctx->SetConstantOutput(4, Tensor(scale_dtype, {})); } private: + TensorFormat data_format_; float epsilon_; - int64 feature_index_; + bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); +REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 1de91924326464338352b1ac9edf77141f25ad35..2436a6074a11ad66387b232dd1c5aa135875bfc3 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" namespace tensorflow { namespace { @@ -75,7 +76,7 @@ static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, auto abs_y = b->Abs(y); auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); - if (dtype == DT_FLOAT || dtype == DT_DOUBLE) { + if (DataTypeIsFloating(dtype)) { result = b->Floor(result); } return result; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 885f716afafca7ba23770e38f6693eed1ba50982..aaddbe811c6fbf6da296640eb5a75e82b2fedcfa 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -46,72 +46,130 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( return expanded_shape; } +// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. +xla::ComputationDataHandle CreateExpandedZero( + const TensorShape& filter_shape, DataType dtype, + xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + return builder->Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::ComputationDataHandle CreateExpandedFilterMask( + const TensorShape& filter_shape, xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::ComputationDataHandle input_feature_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, + &input_feature_iota)); + xla::ComputationDataHandle expanded_feature_iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + input_feature * depthwise_multiplier, + &expanded_feature_iota)); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + builder->Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = builder->Broadcast( + expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); +} + // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter, xla::ComputationBuilder* builder) { - // Filter has shape [H, W, ..., M, N] - // Dilate to [H, W, ..., M*M, N] using M inter-element padding, and then - // reshape to [H, W, ..., M, M*N]. - int num_spatial_dims = filter_shape.dims() - 2; - const int64 in_depth = filter_shape.dim_size(num_spatial_dims); - xla::PaddingConfig padding = xla::MakeNoPaddingConfig(filter_shape.dims()); - padding.mutable_dimensions(num_spatial_dims)->set_interior_padding(in_depth); - auto dilated_filter = - builder->Pad(filter, XlaHelpers::Zero(builder, dtype), padding); - + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Reshape(dilated_filter, expanded_filter_shape.dim_sizes()); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 2, 1); + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 1, + depthwise_multiplier * input_feature); + auto implicit_broadcast_filter = + builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + + // Broadcast the filter to [H, W, ..., M, M*N]. + auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); + auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + + // If the filter mask is set, choose the broadcasted filter, othwerwise, + // choose zero. + return builder->Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - const TensorShape& filter_shape, DataType dtype, + XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter_backprop, xla::ComputationBuilder* builder) { - int num_spatial_dims = filter_shape.dims() - 2; - - // Reshape to [H, W, ..., M*M, N] - TensorShape shape = filter_shape; - int64 in_depth = filter_shape.dim_size(num_spatial_dims); - shape.set_dim(num_spatial_dims, in_depth * in_depth); - auto reshaped = builder->Reshape(filter_backprop, shape.dim_sizes()); - - std::vector zeros(filter_shape.dims()); - std::vector strides(filter_shape.dims(), 1LL); - strides[num_spatial_dims] = in_depth + 1; - return builder->Slice(reshaped, zeros, shape.dim_sizes(), strides); - - // Alternate implementation for backends without strided Slice() support. - // TODO(phawkins): Remove when all backends support strided slice. - // // Pad [..., M * (M + 1), N] - // xla::PaddingConfig config = - // xla::MakeNoPaddingConfig(filter_shape.dims()); - // config.mutable_dimensions(num_spatial_dims) - // ->set_edge_padding_high(in_depth); - // auto zero = XlaHelpers::Zero(builder, dtype); - // auto padded = builder->Pad(reshaped, zero, config); - // - // // Reshape to [..., M, M + 1, N] - // shape = filter_shape; - // shape.set_dim(num_spatial_dims, in_depth); - // shape.set_dim(num_spatial_dims + 1, in_depth + 1); - // int64 out_depth = filter_shape.dim_size(num_spatial_dims + 1); - // shape.AddDim(out_depth); - // reshaped = builder->Reshape(padded, shape.dim_sizes()); - // - // // Slice to [..., M, 1, N] - // std::vector zeros(shape.dims()); - // std::vector strides(shape.dims(), 1LL); - // shape.set_dim(num_spatial_dims + 1, 1); - // auto sliced = builder->Slice(reshaped, zeros, shape.dim_sizes(), - // strides); - // - // // Reshape to [..., M, N] - // return builder->Reshape(sliced, filter_shape.dim_sizes()); + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + auto masked_expanded_filter = builder->Select( + CreateExpandedFilterMask(filter_shape, builder), filter_backprop, + CreateExpandedZero(filter_shape, dtype, builder)); + return builder->Reshape( + builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), + filter_shape.dim_sizes()); } class ConvOp : public XlaOpKernel { @@ -121,6 +179,7 @@ class ConvOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); @@ -144,6 +203,23 @@ class ConvOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + const TensorShape input_shape = ctx->InputShape(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, ..., in_depth, out_depth] @@ -184,10 +260,11 @@ class ConvOp : public XlaOpKernel { dims.set_input_feature_dimension(feature_dim); dims.set_output_feature_dimension(feature_dim); for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dims.add_spatial_dimensions(input_dim); + const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dims.add_input_spatial_dimensions(dim); dims.add_kernel_spatial_dimensions(i); - window_strides.push_back(strides_.at(input_dim)); + dims.add_output_spatial_dimensions(dim); + window_strides.push_back(strides_.at(dim)); } dims.set_kernel_input_feature_dimension(num_spatial_dims_); dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); @@ -203,6 +280,7 @@ class ConvOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -240,6 +318,7 @@ class ConvBackpropInputOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -262,6 +341,23 @@ class ConvBackpropInputOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -302,9 +398,10 @@ class ConvBackpropInputOp : public XlaOpKernel { std::vector lhs_dilation(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); for (int i = 0; i < num_spatial_dims_; ++i) { - dnums.add_spatial_dimensions( - GetTensorSpatialDimIndex(num_dims(), data_format_, i)); + int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); kernel_spatial_dims[i] = i; padding[i] = {dims.spatial_dims[i].pad_before, @@ -334,6 +431,7 @@ class ConvBackpropInputOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -371,6 +469,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -390,6 +489,23 @@ class ConvBackpropFilterOp : public XlaOpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + const TensorShape activations_shape = ctx->InputShape(0); TensorShape filter_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); @@ -424,9 +540,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Swap n_dim and c_dim in the activations. dnums.set_input_batch_dimension(c_dim); - dnums.set_output_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); - dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] @@ -438,9 +552,16 @@ class ConvBackpropFilterOp : public XlaOpKernel { std::vector rhs_dilation(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_spatial_dimensions(dim); + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(num_spatial_dims_); + dnums.set_output_feature_dimension(num_spatial_dims_ + 1); + + for (int i = 0; i < num_spatial_dims_; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); // We will also need to pad the input with zeros such that after the @@ -498,31 +619,17 @@ class ConvBackpropFilterOp : public XlaOpKernel { /*window_strides=*/ones, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums); - // The layout of filter_backprop will match the layout of - // padded_activations - // and so will have layout: [out_feature, h, w, ..., in_feature] - // Tensorflow filter shape is [ H, W, ..., inC, outC ], so we transpose the - // output. - std::vector transpose_dims; - transpose_dims.reserve(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - transpose_dims.push_back(dnums.spatial_dimensions(i)); - } - transpose_dims.push_back(c_dim); - transpose_dims.push_back(n_dim); - xla::ComputationDataHandle filter_backprop_reshaped = - b->Transpose(filter_backprop, transpose_dims); - if (depthwise_) { - filter_backprop_reshaped = ContractFilterForDepthwiseBackprop( - filter_shape, ctx->input_type(0), filter_backprop_reshaped, b); + filter_backprop = ContractFilterForDepthwiseBackprop( + ctx, filter_shape, ctx->input_type(0), filter_backprop, b); } - ctx->SetOutput(0, filter_backprop_reshaped); + ctx->SetOutput(0, filter_backprop); } protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index a4ea65ea89e348cb77412efb0c5c0fcb1a9f33f3..96d7809f7995634b6bc31ab801b93526d9da7e6f 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class DepthToSpaceOp : public XlaOpKernel { public: explicit DepthToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,18 +42,79 @@ class DepthToSpaceOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got: ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if (data_format_ == FORMAT_NHWC) { + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i]); + } + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i + 1); + transpose_order.push_back(i + 1 + num_spatial_dims); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] * block_size_); + } + output_shape.push_back(input_shape[feature_dim] / block_elems); + } else { + // NCHW format. + reshaped_shape.push_back(input_shape[0]); + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i]); + } + + transpose_order.push_back(0); + transpose_order.push_back(1 + num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(2 + num_spatial_dims + i); + transpose_order.push_back(1 + i); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] * block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, @@ -51,14 +123,14 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // block_size_, // depth / (block_size_ * block_size_)] - OP_REQUIRES(ctx, input_shape[3] % (block_size_ * block_size_) == 0, + OP_REQUIRES(ctx, + input_shape[feature_dim] % (block_size_ * block_size_) == 0, errors::InvalidArgument( "Input depth dimension (", input_shape[3], ") is not divisible by square of the block size (", block_size_, ")")); - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1], input_shape[2], block_size_, - block_size_, input_shape[3] / (block_size_ * block_size_)}); + + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -70,7 +142,7 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // depth / (block_size_ * block_size_)] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -80,15 +152,14 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] * block_size_, - input_shape[2] * block_size_, - input_shape[3] / (block_size_ * block_size_)}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("DepthToSpace"), DepthToSpaceOp); diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ec5017f6ab96bd3fc273a746b77fbb7e74fd9f35..765ea922a532a085a552192348ab360c4c30ff0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +24,62 @@ limitations under the License. namespace tensorflow { namespace { +// Create a diagonal / batch diagonal matrix with 'input' on the diagonal. +xla::StatusOr CreateDiagonal( + const xla::ComputationDataHandle& input, int64 last_dim_size, + tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, + xla::ComputationBuilder* builder) { + // Create two matrices that have the following forms, and compare them: + // + // [[0, 0, 0, 0] [[0, 1, 2, 3] + // [1, 1, 1, 1] [0, 1, 2, 3] + // [2, 2, 2, 2] [0, 1, 2, 3] + // [3, 3, 3, 3]] [0, 1, 2, 3]] + // + // This produces a predicate matrix of the right size, with "true" on the + // diagonal. + xla::ComputationDataHandle iota; + TF_RETURN_IF_ERROR( + XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); + xla::ComputationDataHandle iota_broadcast = + builder->Broadcast(iota, {last_dim_size}); + xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); + + // If this is a batched diagonal, broadcast the mask across the other + // dimensions. + if (!other_dims.empty()) { + mask = builder->Broadcast(mask, other_dims); + } + + // Broadcast the input, and then use the mask computed above to select the + // diagonal: + // e.g, in 2D: + // [[t, f, f] [[1, 1, 1] [[0, 0, 0] [[1, 0, 0] + // select( [f, t, f] , [4, 4, 4] , [0, 0, 0] ) = [0, 4, 0] + // [f, f, t]] [9, 9, 9]] [0, 0, 0]] [0, 0, 9]] + // + // Broadcasting the input is less-than-trivial, since we need to broadcast + // into a "middle" dimension. We can do this with a reshape + implicit + // broadcast. + // TODO(b/30112114): Replace with in-dim broadcast when those are supported. + std::vector broadcast_dims(other_dims.begin(), other_dims.end()); + broadcast_dims.push_back(1LL); + broadcast_dims.push_back(last_dim_size); + xla::ComputationDataHandle input_broadcast = + builder->Reshape(input, broadcast_dims); + + broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; + xla::PrimitiveType element_type; + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); + auto broadcast_shape = + xla::ShapeUtil::MakeShape(element_type, broadcast_dims); + xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); + + input_broadcast = builder->Add(input_broadcast, zeros); + return builder->Select(mask, input_broadcast, zeros); +} + class DiagOp : public XlaOpKernel { public: explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -29,6 +87,8 @@ class DiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -36,7 +96,7 @@ class DiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::ComputationDataHandle input = ctx->Input(0); // Picture: // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] @@ -46,13 +106,13 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); + input = builder->Reshape(input, {size}); - // Adds inter-element padding of 'size'. - xla::PaddingConfig config; - auto* dim = config.add_dimensions(); - dim->set_interior_padding(size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + // Create an R2 with the R1 diagonal. + auto diag_or_status = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); @@ -141,6 +201,8 @@ class MatrixDiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -152,17 +214,13 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); + tensorflow::gtl::ArraySlice other_dims(dims); + other_dims.pop_back(); - // Adds inter-element padding of 'last_dim_size' to the last dimension. - xla::PaddingConfig config = xla::MakeNoPaddingConfig(dims.size()); - auto* dim = config.mutable_dimensions(last_dim); - dim->set_interior_padding(last_dim_size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); - - // Reshapes to the final shape. - dims.push_back(last_dim_size); - diag = builder->Reshape(diag, dims); - + auto diag_or_status = + CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + diag = diag_or_status.ValueOrDie(); ctx->SetOutput(0, diag); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91ebb500b4479dbb3c8e2ea7719bc79dc24ba4f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -0,0 +1,367 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +// We implement bilinear interpolation by upsampling followed by convolution. +// The basic idea is as follows. To scale from NxN to RxR: +// +// 1. S := (N - 1) / gcd(N-1, R-1) +// 2. k := (R - 1) / gcd(N-1, R-1) +// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) +// +// For example, to Scale from 7x7 -> 15x15: +// +// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 +// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 +// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2) +// +// +// The 7x7 -> 15x15 case is much too large to write out in full as an +// example. The smallest interesting example is 3x3 -> 4x4. +// +// S := 2 +// k := 3 +// +// 00 03 06 00 00 00 00 00 00 00 00 00 00 00 00 02 04 06 +// 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12 +// 18 21 24 00 00 00 00 00 03 00 00 06 00 00 12 14 16 18 +// 00 00 00 00 00 00 00 00 00 00 00 18 20 22 24 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 09 00 00 12 00 00 15 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 18 00 00 21 00 00 24 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// +// with the following convolutional kernel, with stride [2, 2]: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 + +// Computes the size of the convolutional kernel and stride to use when resizing +// from in_size to out_size. +struct ResizeConvolutionDims { + // Size of the kernel to use. + std::vector kernel_size; + + // Stride of the convolution to use. + std::vector stride; +}; +ResizeConvolutionDims ComputeResizeConvolutionParameters( + gtl::ArraySlice in_size, gtl::ArraySlice out_size) { + CHECK_EQ(in_size.size(), out_size.size()); + int num_spatial_dims = in_size.size(); + ResizeConvolutionDims dims; + dims.kernel_size.resize(num_spatial_dims); + dims.stride.resize(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1) { + // We must handle input size 1 specially because XLA convolution does + // not allow stride 0. + dims.stride[i] = dims.kernel_size[i] = 1; + } else if (out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + dims.stride[i] = dims.kernel_size[i] = 1; + } else { + int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1), + static_cast(out_size[i] - 1)); + dims.stride[i] = (in_size[i] - 1) / gcd; + dims.kernel_size[i] = (out_size[i] - 1) / gcd; + } + } + return dims; +} + +xla::ComputationDataHandle MakeBilinearResizeKernel( + xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size, + int64 channels) { + // Form a 2D convolution kernel like: + // 1 2 3 2 1 + // 2 4 6 4 2 + // 1/9 * 3 6 9 6 3 + // 2 4 6 4 2 + // 1 2 3 2 1 + // by multiplying two 1D kernels of the form: + // 1/3 * [1 2 3 2 1] + auto make_1d_kernel = [](int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = i + 1; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; + }; + + // Form a block diagonal kernel where each channel interacts only with itself. + xla::Array4D diag(1, 1, channels, channels, 0.0f); + for (int i = 0; i < channels; ++i) { + diag(0, 0, i, i) = 1.0f / (kernel_size[0] * kernel_size[1]); + } + return builder->Mul( + builder->ConstantR1(make_1d_kernel(kernel_size[0])), + builder->Mul(builder->ConstantR1(make_1d_kernel(kernel_size[1])), + builder->ConstantR4FromArray4D(diag), + /*broadcast_dimensions=*/{1}), + /*broadcast_dimensions=*/{0}); +} + +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented( + "ResizeBilinear with align_corners=False is not yet implemented")); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + const std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + std::vector out_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); + OP_REQUIRES(ctx, out_size.size() == 2, + errors::InvalidArgument("output size must be length 2, got ", + out_size.size())); + OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, + errors::InvalidArgument("output size must be positive, got [", + out_size[0], ",", out_size[1], "]")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle input = ctx->Input(0); + + // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in + // dimension i. + std::vector slice_size = in_size; + bool slice_input = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + slice_input = true; + slice_size[i] = 1; + } + } + if (slice_input) { + input = b->Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); + } + + // Output is always type float. + input = b->ConvertElementType(input, xla::F32); + + // Picture for a 1x3 to 1x4 resize: + // stride = 2, kernel size = 3 + // Input: + // 3 6 9 + // Input with dilation and padding: + // 0 0 3 0 0 6 0 0 9 0 0 + // Convolution kernel: + // 1/3 * [1 2 3 2 1] + // Output: + // 3 5 7 9 + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dnums.add_input_spatial_dimensions(1 + i); + dnums.add_output_spatial_dimensions(1 + i); + dnums.add_kernel_spatial_dimensions(i); + } + dnums.set_kernel_input_feature_dimension(num_spatial_dims); + dnums.set_kernel_output_feature_dimension(num_spatial_dims + 1); + + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, out_size); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(b, dims.kernel_size, channels); + xla::ComputationDataHandle output = b->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dnums); + + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + output = b->Add(output, b->ConstantR1(out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); + } + } + + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinear"), ResizeBilinearOp); + +class ResizeBilinearGradOp : public XlaOpKernel { + public: + explicit ResizeBilinearGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented("ResizeBilinearGrad with align_corners=False is " + "not yet implemented")); + + DataType output_dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + const std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + TensorShape grad_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, grad_shape.dims() == 4, + errors::InvalidArgument("gradient must be 4-dimensional", + grad_shape.DebugString())); + const int64 grad_batch = grad_shape.dim_size(0); + const std::vector grad_size = {grad_shape.dim_size(1), + grad_shape.dim_size(2)}; + const int64 grad_channels = grad_shape.dim_size(3); + OP_REQUIRES(ctx, batch == grad_batch, + errors::InvalidArgument( + "activations and gradients must have the same batch size (", + batch, " vs. ", grad_batch, ")")); + OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0, + errors::InvalidArgument("gradient size must be positive, got [", + grad_size[0], ",", grad_size[1], "]")); + OP_REQUIRES( + ctx, channels == grad_channels, + errors::InvalidArgument( + "activations and gradients must have the same number of channels (", + channels, " vs. ", grad_channels, ")")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle grad = ctx->Input(0); + + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, grad_size); + + // To form the backward convolution, we keep the kernel unchanged (it is + // already symmetric) and swap the roles of strides and LHS dilation. + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dnums.add_input_spatial_dimensions(1 + i); + dnums.add_output_spatial_dimensions(1 + i); + dnums.add_kernel_spatial_dimensions(i); + } + dnums.set_kernel_input_feature_dimension(num_spatial_dims); + dnums.set_kernel_output_feature_dimension(num_spatial_dims + 1); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(b, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = b->Add(kernel, b->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } + + xla::ComputationDataHandle output = b->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dnums); + + // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. + // Opposite of the slice performed by the forward op. + xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4); + bool pad_output = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && grad_size[i] == 1) { + pad_output = true; + padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - + 1); + } + } + if (pad_output) { + output = b->Pad(output, b->ConstantR0(0.0f), padding); + } + + output = b->ConvertElementType(output, output_type_); + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; + xla::PrimitiveType output_type_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index fcef497e5845d9080bc83b54e92dcf2fdecf5f12..644abd5905c6ce5a8f61792a1986560bab891040 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,8 +23,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; class MatMulOp : public XlaOpKernel { public: @@ -85,10 +85,7 @@ class SparseMatMulOp : public MatMulOp { ~SparseMatMulOp() override = default; }; -REGISTER_XLA_OP(Name("SparseMatMul") - .TypeConstraint("Ta", kFloatTypes) - .TypeConstraint("Tb", kFloatTypes), - SparseMatMulOp); +REGISTER_XLA_OP(Name("SparseMatMul"), SparseMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..650f8c7dc8be0cb08997ec641ca3f82352166fdd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// TODO(phawkins): implement double-sized windowed reductions in XLA and remove +// the type constraint. +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; + +class ScanOp : public XlaOpKernel { + public: + ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape tensor_axis_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape), + errors::InvalidArgument("ScanOp: axis must be a scalar, not ", + tensor_axis_shape.DebugString())); + + int64 axis; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis)); + if (axis < 0) { + axis += input_shape.dims(); + } + OP_REQUIRES( + ctx, FastBoundsCheck(axis, input_shape.dims()), + errors::InvalidArgument("ScanOp: Expected scan axis in the range [", + -input_shape.dims(), ", ", input_shape.dims(), + "), but got ", axis)); + + DataType dtype = ctx->input_type(0); + + if (input_shape.num_elements() == 0) { + // Exit early if there is nothing to compute. + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + xla::ComputationBuilder* builder = ctx->builder(); + + std::vector window_strides(input_shape.dims(), 1); + std::vector window_dims(input_shape.dims(), 1); + window_dims[axis] = input_shape.dim_size(axis); + + std::vector> padding(input_shape.dims(), {0, 0}); + padding[axis].first = input_shape.dim_size(axis) - 1; + // In exclusive mode, add an extra padding element so there is a complete + // window of padding before the data starts. + if (exclusive_) { + ++padding[axis].first; + } + if (reverse_) { + std::swap(padding[axis].first, padding[axis].second); + } + + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle init; + const xla::Computation* reducer; + if (sum_) { + init = XlaHelpers::Zero(builder, dtype); + reducer = ctx->GetOrCreateAdd(dtype); + } else { + init = XlaHelpers::One(builder, dtype); + reducer = ctx->GetOrCreateMul(dtype); + } + auto output = builder->ReduceWindowWithGeneralPadding( + ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + + // In exclusive mode, we have computed an extra element containing the sum + // of all the input elements. Slice off this extra "last" element. + if (exclusive_) { + if (reverse_) { + output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, + 1, axis); + + } else { + output = + builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + } + } + ctx->SetOutput(0, output); + } + + private: + const bool sum_; // True=cumulative sum. False=cumulative product. + bool reverse_; + bool exclusive_; +}; + +class CumsumOp : public ScanOp { + public: + explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {} +}; +REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp); + +class CumprodOp : public ScanOp { + public: + explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {} +}; +REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 24a99f253d6dc8bb699fff587c363b12c227e821..e205fadd2b1bcae96a7bfa1bc83096d405ce22c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Shape Ops. +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -27,56 +28,42 @@ namespace { class ShapeOp : public XlaOpKernel { public: - explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int rank = input_shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({rank})); - auto vec = shape_constant.vec(); - // TODO(dga): support int64. b/28119922. - for (int i = 0; i < rank; ++i) { - int64 dim_size = input_shape.dim_size(i); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but dim ", i, " is ", dim_size)); - vec(i) = static_cast(dim_size); - } - + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(0, shape_constant); } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("Shape"), ShapeOp); class ShapeNOp : public XlaOpKernel { public: - explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - const TensorShape shape = ctx->InputShape(i); - const int dims = shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({dims})); - auto vec = shape_constant.vec(); - - // TODO(dga): support int64. b/28119922. - for (int j = 0; j < dims; ++j) { - int64 dim_size = shape.dim_size(j); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but shape ", i, " dim ", j, " is ", - dim_size)); - vec(j) = static_cast(dim_size); - } - + const TensorShape input_shape = ctx->InputShape(i); + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(i, shape_constant); } } bool IsExpensive() override { return false; } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..76ea5f525598f511f295eb5a30f3cf603fbf57aa --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" + +#include + +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { + +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant) { + const int dims = input_shape.dims(); + if (shape_constant->dtype() == DT_INT32) { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + if (!FastBoundsCheck(dim_size, std::numeric_limits::max())) { + return errors::InvalidArgument( + "Shape with out_type=int32 does not support tensors > int32max", + " but dim ", i, " is ", dim_size); + } + vec(i) = static_cast(dim_size); + } + } else { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + vec(i) = dim_size; + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h new file mode 100644 index 0000000000000000000000000000000000000000..575086e118080f6799a54d3ae6409b2b641c4341 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Converts a TensorShape to a constant Tensor. +// +// The input TensorShape input_shape is used to populate the elements of +// shape_constant, which is modified in place. +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 89befda346ec06fec23ab1d1c9d910ded8cd806d..806fda632cde64c1b37ae3b9199028d6b6b0a215 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class SpaceToDepthOp : public XlaOpKernel { public: explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,34 +42,100 @@ class SpaceToDepthOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if (data_format_ == FORMAT_NHWC) { + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 1 + i, "]=", input_shape[1 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + reshaped_shape.push_back(input_shape[feature_dim]); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 1); + } + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] / block_size_); + } + output_shape.push_back(input_shape[feature_dim] * block_elems); + } else { + // FORMAT_NCHW + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[2 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 2 + i, "]=", input_shape[2 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + reshaped_shape.push_back(input_shape[feature_dim]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 3); + } + transpose_order.push_back(feature_dim); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] * block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] / block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - const int block_rank = 2; - for (int i = 0; i < block_rank; ++i) { - OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, - errors::InvalidArgument( - "input shape[", 1 + i, "]=", input_shape[1 + i], - " is not divisible by block_size=", block_size_)); - } - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1] / block_size_, block_size_, - input_shape[2] / block_size_, block_size_, input_shape[3]}); + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -69,7 +146,7 @@ class SpaceToDepthOp : public XlaOpKernel { // block_size_, block_size_, // depth] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -79,15 +156,14 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] / block_size_, - input_shape[2] / block_size_, - block_size_ * block_size_ * input_shape[3]}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..b10880de77e6b9811008076cd4a959c284e558d1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -0,0 +1,279 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +// Rotates a 32-bit integer 'v' left by 'distance' bits. +xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& v, + int distance) { + return builder->Or( + builder->ShiftLeft(v, builder->ConstantR0(distance)), + builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); +} + +// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than +// building XOR out of other bitwise operators. +xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& y) { + return builder->Or(builder->And(x, builder->Not(y)), + builder->And(builder->Not(x), y)); +} + +using ThreeFry2x32State = std::array; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, + ThreeFry2x32State input, ThreeFry2x32State key) { + // Rotation distances specified by the Threefry2x32 algorithm. + constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; + ThreeFry2x32State x; + + std::array ks; + // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. + ks[2] = builder->ConstantR0(0x1BD11BDA); + for (int i = 0; i < 2; ++i) { + ks[i] = key[i]; + x[i] = input[i]; + ks[2] = BitwiseXor(builder, ks[2], key[i]); + } + + x[0] = builder->Add(x[0], ks[0]); + x[1] = builder->Add(x[1], ks[1]); + + // Performs a single round of the Threefry2x32 algorithm, with a rotation + // amount 'rotation'. + auto round = [builder](ThreeFry2x32State v, int rotation) { + v[0] = builder->Add(v[0], v[1]); + v[1] = RotateLeftS32(builder, v[1], rotation); + v[1] = BitwiseXor(builder, v[0], v[1]); + return v; + }; + + // There are no known statistical flaws with 13 rounds of Threefry2x32. + // We are conservative and use 20 rounds. + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[1]); + x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(1)); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = builder->Add(x[0], ks[2]); + x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(2)); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[0]); + x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0(3)); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = builder->Add(x[0], ks[1]); + x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(4)); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[2]); + x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(5)); + + return x; +} + +// Returns a tensor of 'shape' random values uniformly distributed in the range +// [minval, maxval) +xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& seed, + const TensorShape& shape, + double minval, double maxval) { + // Split the seed into two 32-bit scalars to form a key. + auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); + ThreeFry2x32State key = {seed0, seed1}; + const int64 size = shape.num_elements(); + + const int64 half_size = MathUtil::CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + + // Fill the generator inputs with unique counter values. + ThreeFry2x32State inputs; + TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); + inputs[1] = builder->Add(inputs[0], builder->ConstantR0(half_size)); + ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); + + if (size_is_odd) { + outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + + auto bits = + builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes()); + + // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit + // forces the random bits into the mantissa. + constexpr int kFloatBits = 32; + constexpr int kMantissaBits = 23; + bits = builder->Or( + builder->ShiftRightLogical( + bits, builder->ConstantR0(kFloatBits - kMantissaBits)), + builder->ConstantR0(bit_cast(1.0f))); + auto floats = builder->BitcastConvertType(bits, xla::F32); + + // We have a floating point number in the range [1.0, 2.0). + // Subtract 1.0f to shift to the range [0.0, 1.0) + floats = builder->Sub(floats, builder->ConstantR0(1.0f)); + // Multiply and add to shift to the range [minval, maxval). + floats = builder->Mul(floats, builder->ConstantR0(maxval - minval)); + floats = builder->Add(floats, builder->ConstantR0(minval)); + return floats; +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b, + const xla::ComputationDataHandle& x, + const TensorShape& shape) { + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = b->ConstantR0(1.0); + auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); + + auto lt = b->Lt(w, b->ConstantR0(5.0)); + auto coefficient = [&](int i) { + return b->Select( + lt, + b->Broadcast(b->ConstantR0(w_less_than_5_constants[i]), + shape.dim_sizes()), + b->Broadcast(b->ConstantR0(w_greater_than_5_constants[i]), + shape.dim_sizes())); + }; + w = b->Select(lt, b->Sub(w, b->ConstantR0(2.5f)), + b->Sub(b->SqrtF32(w), b->ConstantR0(3.0f))); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = b->Add(coefficient(i), b->Mul(p, w)); + } + return b->Mul(p, x); +} + +} // namespace + +class StatelessRandomUniformOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::ComputationDataHandle seed = ctx->Input(1); + ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp); +}; + +// TODO(phawkins): generalize to non-float, non-int32 seed types. +REGISTER_XLA_OP(Name("StatelessRandomUniform") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessRandomUniformOp); + +class StatelessRandomNormalOp : public XlaOpKernel { + public: + explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape == TensorShape({2}), + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::ComputationDataHandle seed = ctx->Input(1); + xla::ComputationBuilder* builder = ctx->builder(); + auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); + // Convert uniform distribution to normal distribution by computing + // sqrt(2) * erfinv(x) + auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), + ErfInvF32(builder, uniform, shape)); + ctx->SetOutput(0, normal); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp); +}; + +// TODO(phawkins): generalize to non-float, non-int32 seed types. +REGISTER_XLA_OP(Name("StatelessRandomNormal") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessRandomNormalOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 351fda251798e43b607fb445f2c98abd57b3d86b..03c22354a9425189e6cf7ee5a7201c90ecb1908d 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -311,6 +311,32 @@ class TensorArrayGatherOp : public XlaOpKernel { xla::ComputationDataHandle ta = resource->value; + // Look for the case where the gather takes a simple slice from the + // tensor array (0, 1, 2, 3, 4, ..., N) + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok()) { + bool gather_is_dense_slice = true; + for (auto i = 0; i < const_indices.size(); i++) { + if (const_indices[i] != i) { + gather_is_dense_slice = false; + break; + } + } + + if (gather_is_dense_slice) { + std::vector begin(ta_shape.dims(), 0); + std::vector strides(ta_shape.dims(), 1); + std::vector end(ta_shape.dims(), 1); + end[0] = const_indices.size(); + for (auto i = 1; i < ta_shape.dims(); i++) { + end[i] = ta_shape.dim_size(i); + } + ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + return; + } + } + xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); ctx->SetOutput(0, gather); @@ -352,28 +378,47 @@ class TensorArrayScatterOp : public XlaOpKernel { const xla::ComputationDataHandle value = ctx->Input(2); const xla::ComputationDataHandle flow = ctx->Input(3); - auto slice_dims = value_shape.dim_sizes(); - slice_dims[0] = 1LL; - - std::vector value_starts(value_shape.dims(), 0); - auto value_ends = value_shape.dim_sizes(); - - std::vector value_strides(value_shape.dims(), 1); - - // For every (index, value) pair, update the corresponding TensorArray - // storage. - for (int i = 0; i < num_indices; ++i) { - // Slice out part of the value. - value_starts[0] = i; - value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + // Look for the case where the scatter is for each sub-tensor in order. The + // tensor array implementation allows for this to be a straight addition. + bool scatter_all_elements_in_order = false; + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok() && num_indices == value_shape.dim_size(0)) { + scatter_all_elements_in_order = true; + for (auto i = 0; i < num_indices; i++) { + if (const_indices[i] != i) { + scatter_all_elements_in_order = false; + break; + } + } + } - // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + if (scatter_all_elements_in_order) { + ta = b->Add(ta, value); + } else { + auto slice_dims = value_shape.dim_sizes(); + slice_dims[0] = 1LL; + + std::vector value_starts(value_shape.dims(), 0); + auto value_ends = value_shape.dim_sizes(); + + std::vector value_strides(value_shape.dims(), 1); + + // For every (index, value) pair, update the corresponding TensorArray + // storage. + for (int i = 0; i < num_indices; ++i) { + // Slice out part of the value. + value_starts[0] = i; + value_ends[0] = i + 1; + auto slice = b->Slice(value, value_starts, value_ends, value_strides); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto start_indices = + b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + } } resource->value = ta; diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index b19ea22f50d2dd44e8d1d81f5930263f364030e1..68847ae7a2cb926edd9d29007e24b0db7fb5a75f 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/no_op.h" namespace tensorflow { @@ -121,5 +123,26 @@ class ResourceGatherOp : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), ResourceGatherOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType variable_dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); + Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } + + private: + DataType out_dtype_; +}; + +REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 28a5e6a58bb312f4c4821bcce484a08160009d56..9b0e6174475c22e325c090bec5f1d56822e106bc 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -27,7 +27,6 @@ namespace tensorflow { // The current implementation simply unrolls the computation along the batch // dimension. -// TODO(andydavis): add batching support to XLA's Dot operator. xla::StatusOr BatchDot( xla::ComputationBuilder* builder, xla::ComputationDataHandle x, xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) { @@ -52,26 +51,20 @@ xla::StatusOr BatchDot( // The batch dimensions must be equal and the matrix dimensions must be // valid. - std::vector dimensions; - int64 batch_count = 1; + std::vector batch_dimension_numbers; for (int i = 0; i < ndims - 2; ++i) { - int64 x_size = x_shape->dimensions(i); - int64 y_size = y_shape->dimensions(i); - if (x_size != y_size) { + if (x_shape->dimensions(i) != y_shape->dimensions(i)) { return errors::InvalidArgument( "Dimension ", i, " of inputs to BatchedDot must be equal: ", xla::ShapeUtil::HumanString(*x_shape), " vs ", xla::ShapeUtil::HumanString(*y_shape)); } - dimensions.push_back(x_size); - batch_count *= x_size; + batch_dimension_numbers.push_back(i); } int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - int64 x_inner_dim_size = x_shape->dimensions(x_inner_dim); - int64 y_inner_dim_size = y_shape->dimensions(y_inner_dim); - if (x_inner_dim_size != y_inner_dim_size) { + if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { return errors::InvalidArgument( "Dimensions ", x_inner_dim, " and ", y_inner_dim, " of arguments to BatchedDot must be equal: ", @@ -80,19 +73,22 @@ xla::StatusOr BatchDot( " transpose: ", transpose_y); } - // If there are no batch dimensions, use a regular Dot. This case exists - // to improve the readability of the emitted graphs. - if (dimensions.empty()) { - auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; - return builder->Dot(lhs, rhs); + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::HasZeroElements(*x_shape) || + xla::ShapeUtil::HasZeroElements(*y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape->dimensions(x_outer_dim)); + dimensions.push_back(y_shape->dimensions(y_outer_dim)); + return builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), + dimensions); } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape->dimensions(x_outer_dim)); - dimensions.push_back(y_shape->dimensions(y_outer_dim)); - if (x_shape->element_type() == xla::C64 && transpose_x) { x = builder->Conj(x); } @@ -100,55 +96,23 @@ xla::StatusOr BatchDot( y = builder->Conj(y); } - // Reshape input tensors into 3D tensors by flattening the batch - // dimensions. This makes it easier to unroll the batch dimension. - auto x_flat = - builder->Reshape(x, {batch_count, x_shape->dimensions(ndims - 2), - x_shape->dimensions(ndims - 1)}); - auto y_flat = - builder->Reshape(y, {batch_count, y_shape->dimensions(ndims - 2), - y_shape->dimensions(ndims - 1)}); - - // Slice batches into individual matrices and multiply them. - std::vector out_slices; - for (int64 i = 0; i < batch_count; ++i) { - // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder->Slice( - x_flat, {i, 0, 0}, - {i + 1, x_shape->dimensions(ndims - 2), x_shape->dimensions(ndims - 1)}, - {1, 1, 1}); - x_slice = builder->Reshape(x_slice, {x_shape->dimensions(ndims - 2), - x_shape->dimensions(ndims - 1)}); - auto y_slice = builder->Slice( - y_flat, {i, 0, 0}, - {i + 1, y_shape->dimensions(ndims - 2), y_shape->dimensions(ndims - 1)}, - {1, 1, 1}); - y_slice = builder->Reshape(y_slice, {y_shape->dimensions(ndims - 2), - y_shape->dimensions(ndims - 1)}); - - // Transpose if needed. - auto lhs = transpose_x ? builder->Transpose(x_slice, {1, 0}) : x_slice; - auto rhs = transpose_y ? builder->Transpose(y_slice, {1, 0}) : y_slice; - - // Multiply matrices and add an outer singleton dimension to the output - // so we can concatenate along the flattened batch dimension later. - auto out = builder->Dot(lhs, rhs); - out = builder->Reshape(out, - {1, dimensions[ndims - 2], dimensions[ndims - 1]}); - out_slices.push_back(out); + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; + return builder->Dot(lhs, rhs); } - // Concatenate output slices and reshape to original number of dimensions. - xla::ComputationDataHandle data; - if (out_slices.empty()) { - // It is illegal to pass an empty list to ConcatInDim. - // The batch count is empty, so both inputs must have zero elements. - // Arbitrarily use the left input as the argument to Reshape(). - data = x; - } else { - data = builder->ConcatInDim(out_slices, 0); + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - return builder->Reshape(data, dimensions); + return builder->DotGeneral(x, y, dot_dnums); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 7ffe0aa6df9b21c4311eb6c8d311fba1e115b3f4..ce24b61b5dc7176f3caa05e3eb9257399fef7926 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -28,7 +28,7 @@ limitations under the License. namespace tensorflow { xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - xla::Shape& shape) { + const xla::Shape& shape) { return builder->Broadcast( builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); @@ -40,6 +40,9 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, case xla::F16: return builder->ConstantR0(static_cast(value)); break; + case xla::BF16: + return builder->ConstantR0(static_cast(value)); + break; case xla::F32: return builder->ConstantR0(static_cast(value)); break; diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 8fba6b5cf247e9b2c26533c53ece8b0d7d4f4c36..fb138b4f736500aac8184770d97fbf930ced69ea 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -25,7 +25,7 @@ namespace tensorflow { // Returns a zero-filled tensor with shape `shape`. xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - xla::Shape& shape); + const xla::Shape& shape); // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index d9c839b61019b92b6de3a77a7bec610ae848a9a4..b08a7583cb5ab7efa30a1fa27b973d04992584a7 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,34 +14,59 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +namespace { +const char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE"; +const char kShardingAttribute[] = "_XlaSharding"; +} // namespace -static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE"; +namespace { +xla::StatusOr> +GetShardingFromNodeDef(const NodeDef& node_def) { + if (!HasNodeAttr(node_def, kShardingAttribute)) { + return tensorflow::gtl::optional(); + } + string value; + xla::OpSharding sharding; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); + if (!sharding.ParseFromString(value)) { + return xla::InvalidArgument( + "Experimental _XlaSharding attribute was not a valid encoded " + "xla::OpSharding proto."); + } + return tensorflow::gtl::optional(sharding); +} -static Status CoreOutOfRangeError(int core, int num_cores_per_replica) { +Status CoreOutOfRangeError(int core, int num_cores_per_replica) { return errors::InvalidArgument( "Invalid replicated core id: ", core, "; num_cores_per_replica=", num_cores_per_replica); } +} // namespace xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { +ParseShardingFromDevice( + const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional explicit_sharding) { if (device_name.empty()) { return tensorflow::gtl::optional(); } - DeviceNameUtils::ParsedName parsed_device; if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { return errors::InvalidArgument("Malformed assigned device '", device_name, "'"); } - if (!parsed_device.has_type || - !StringPiece(parsed_device.type) - .ends_with(DEVICE_SUFFIX_REPLICATED_CORE)) { + + if (explicit_sharding.has_value()) { + return explicit_sharding; + } else if (!parsed_device.has_type || !parsed_device.has_id || + !StringPiece(parsed_device.type) + .contains(kDeviceSuffixReplicatedCore)) { return tensorflow::gtl::optional(); } else { const int core = parsed_device.id; @@ -53,20 +78,34 @@ ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { } } +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) { + const string& device_name = node_def.device(); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node_def)); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); +} + xla::StatusOr> ParseShardingFromDevice(const Node& node, int num_cores_per_replica) { string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } - return ParseShardingFromDevice(device_name, num_cores_per_replica); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node.def())); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { string device_name = src.assigned_device_name(); if (device_name.empty()) { device_name = src.requested_device(); } dst->set_assigned_device_name(device_name); + if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) { + dst->AddAttr(kShardingAttribute, *attr); + } } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index f6468bba9f950fec88dcc6b3ec760f014d3a0ef3..9e430e30a1247c7d01910b6d57f7c577964e1dd1 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -29,14 +29,21 @@ namespace tensorflow { // - if the device name is invalid. // - the core is parsed and is out of the range [0, num_cores_per_replica). // -// Otherwise, returns either a non-value or a sharding set as per -// xla:ShardingBuilder::AssignDevice. +// Otherwise, returns either: +// - explicit_sharding if explicit_sharding.has_value() +// - a non-value if there is no assigned core or +// - a sharding set as per xla::ShardingBuilder::AssignDevice. xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica); +ParseShardingFromDevice(const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional + explicit_sharding = tensorflow::gtl::nullopt); xla::StatusOr> ParseShardingFromDevice(const Node& node, int num_cores_per_replica); +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica); + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index a14c93a2b9494b89f579bc20ee0510c136f8f01b..906f2290433face4cce3296b2f815d50d8c496ce 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -253,8 +253,7 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -277,7 +276,6 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), "tfcompile", std::move(graph), xla_args, &result)); - *requires_runtime_context = result.requires_runtime_context; *computation = std::move(*result.computation); int num_const_results = 0; @@ -352,12 +350,10 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); - TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation, - requires_runtime_context)); + TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index ab99beebf7946237425d4d304a858ac6817177b8..473c431b12d441c652f1d0d6c11c5e87836ab36d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -30,13 +30,9 @@ namespace tensorflow { // // The computation is built in the context of the given `client`, which may // subsequently be used to compile or execute the computation. -// -// If `requires_runtime_context` is filled with true, this indicates the last -// argument of the computation is XlaLocalRuntimeContext*. Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context); + xla::Computation* computation); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..7aca889a266439538c4cd1c153460e6cc871b246 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace tf2xla { +namespace { + +void PrintSupportedOps(const string& device, const string& regen_run) { + XlaOpRegistry::RegisterCompilationKernels(); + + std::vector kdefs = + XlaOpRegistry::DeviceKernels(device, + /*include_compilation_only_kernels=*/true); + std::sort( + kdefs.begin(), kdefs.end(), + [](const KernelDef* a, const KernelDef* b) { return a->op() < b->op(); }); + + std::cout << "**Supported operators for device: " << device << "**\n\n" + << "Operator | Type Constraint\n" + << "-------- | ---------------" << std::endl; + for (const KernelDef* kdef : kdefs) { + std::vector constraints; + for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) { + std::vector types; + for (int type : constraint.allowed_values().list().type()) { + types.push_back(DataTypeString(static_cast(type))); + } + std::sort(types.begin(), types.end()); + constraints.push_back("`" + constraint.name() + "={" + + str_util::Join(types, ",") + "}`"); + } + std::cout << "`" << kdef->op() << "` | " + << str_util::Join(constraints, "
") << std::endl; + } + + std::cout << "\nTo regenerate this table, run:\n\n```shell\n" + << regen_run << " --device=" << device << "\n```" << std::endl; +} + +} // namespace + +void SupportedOpsMain(int argc, char** argv, const char* regen_run) { + std::vector device_names = XlaOpRegistry::BackendNames(); + std::sort(device_names.begin(), device_names.end()); + + // Set up and parse flags. + string device; + std::vector flag_list = { + {"device", &device, + "Name of the compilation device for which to print supported ops, " + "one of: " + + str_util::Join(device_names, ",")}, + }; + string usage = Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + QCHECK(XlaOpRegistry::IsBackendRegistered(device)) + << "\nUnknown device: " << device << "\n" + << usage; + + // Run the program. + port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " + "other than flags\n\n" + << usage; + PrintSupportedOps(device, regen_run); +} + +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..1b45fb4cdd3b0173b04e130b7416874a9a406dc5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ + +namespace tensorflow { +namespace tf2xla { + +// The implementation of a main function for a binary that prints a table of +// supported tf2xla operators for a given device, along with their type +// constraints, to stdout. +// +// Pass the argc and argv from main, unmodified. Use regen_run to specify the +// command used to regenerate the table. +void SupportedOpsMain(int argc, char** argv, const char* regen_run); + +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..690666c2400d45e33c1a5d1818b68a86a70a5be3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc @@ -0,0 +1,22 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +int main(int argc, char** argv) { + const char* regen_run = + "bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops"; + tensorflow::tf2xla::SupportedOpsMain(argc, argv, regen_run); +} diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ecd15652fe84b0c19d2f7fc18f877236547f9be9..a9978e697b091715ce120f0d18fdddd259e08b32 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -70,10 +70,7 @@ TEST(ConvertGraphDefToXla, Sum) { xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); xla::Computation computation; - bool requires_runtime_context; - TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation, - &requires_runtime_context)); - ASSERT_FALSE(requires_runtime_context); + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. auto x_literal = xla::Literal::CreateR0(10); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 55f2f3149c6ba7bfa18608f961c8a76103a50756..f428a194328935fec1210ea96245344de859e611 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -88,8 +88,8 @@ Status ValidateConfig(const tf2xla::Config& config) { TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names)); } TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names)); - if (config.feed().empty() || config.fetch().empty()) { - return errors::InvalidArgument("feeds and fetches must be specified"); + if (config.fetch().empty()) { + return errors::InvalidArgument("fetches must be specified"); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 436039e154842443f779aba276bc571fc2ab7537..ed10d80609641b090cf78bf2e17364fe2fa89c31 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -58,24 +58,14 @@ TEST(ValidateConfig, Good) { TEST(ValidateConfig, BadEmpty) { tf2xla::Config config; - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); -} - -TEST(ValidateConfig, BadNoFeed) { - tf2xla::Config config; - tf2xla::Fetch* fetch = config.add_fetch(); - fetch->mutable_id()->set_node_name("foo"); - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); + ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); } TEST(ValidateConfig, BadNoFetch) { tf2xla::Config config; tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); + ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); } TEST(ValidateConfig, BadFeedNodeName) { diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 4f32c29954b2d809d31ef8c584b6a6c3dcdf5cef..cc459dc87c00f19230c65341d53da213e07fe364 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -100,7 +100,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, b->SetOpMetadata(metadata); auto sharding_parse_result = ParseShardingFromDevice( - op_kernel->requested_device(), std::numeric_limits::max()); + op_kernel->def(), std::numeric_limits::max()); OP_REQUIRES_OK(context, sharding_parse_result.status()); tensorflow::gtl::optional op_sharding = sharding_parse_result.ValueOrDie(); diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index b5c17c5273bb15e20184b2fefd93880d4828105e..79da701fd244a461a60588153b601d5c1870fa89 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -28,9 +28,10 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, temps_(new void*[static_data.num_temps]), arg_names_(static_data.arg_names), result_names_(static_data.result_names), - program_shape_(static_data.program_shape) { + program_shape_(static_data.program_shape), + hlo_profile_printer_(static_data.hlo_profile_printer) { // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( static_data.arg_sizes, static_data.num_args, args_, /*annotate_initialized=*/false); @@ -39,9 +40,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); - // The runtime context is always the last arg, if it is required. - if (static_data.requires_runtime_context) { - args_[static_data.num_args - 1] = &context_; + // If Hlo profiling is enabled the generated code expects an appropriately + // sized buffer to be passed in as the last argument. If Hlo profiling is + // disabled the last function argument is still present in the function + // signature, but it is ignored by the generated code and we pass in null for + // it. + if (hlo_profiling_enabled()) { + profile_counters_ = new int64[static_data.profile_counters_size](); } } @@ -50,6 +55,7 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] profile_counters_; } namespace { diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index f49a7889222ff989144217ab10b27595f89e4311..e0ae3ed9a811bcc49ce8862037a67d293e879e57 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -16,10 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ -#include +#include #include -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -27,6 +26,7 @@ limitations under the License. // never use this functionality. namespace xla { class ProgramShape; +class HloProfilePrinter; } namespace tensorflow { @@ -48,12 +48,10 @@ namespace tensorflow { class XlaCompiledCpuFunction { public: // Type of the raw function, produced by either JIT or AOT. - // - // TODO(toddw): Add support for hlo profiling, and replace std::function with - // a raw function pointer, for some codesize savings. - using RawFunction = std::function; + using RawFunction = void (*)(void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps, + int64* profile_counters); // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for @@ -71,9 +69,6 @@ class XlaCompiledCpuFunction { // The 0-based index of the result tuple, in the temp buffers. size_t result_index = 0; - // Is the final arg XlaLocalRuntimeContext? - bool requires_runtime_context = false; - // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. const char** arg_names = nullptr; @@ -81,21 +76,29 @@ class XlaCompiledCpuFunction { // [Optional] Arg and result shapes. const xla::ProgramShape* program_shape = nullptr; + + // [Optional] Profile printer. Null if profiling is disabled. + const xla::HloProfilePrinter* hlo_profile_printer = nullptr; + + // [Optional] The number of profile counters expected in the profile counter + // buffer by the generated code and hlo_profile_printer. 0 if profiling is + // disabled. + int64 profile_counters_size = 0; }; // AllocMode controls the buffer allocation mode. enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, + // Allocate all buffers - args, results, profile and temps. + ARGS_RESULTS_PROFILES_AND_TEMPS, - // Only allocate result and temp buffers. + // Only allocate result, profile and temp buffers. // Use set_arg_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, + RESULTS_PROFILES_AND_TEMPS_ONLY, }; XlaCompiledCpuFunction( const StaticData& static_data, - AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; @@ -104,21 +107,22 @@ class XlaCompiledCpuFunction { // Sets the intra-op thread pool used to run individual ops concurrently. void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { run_options_.set_intra_op_thread_pool(pool); - context_.thread_pool = pool; } // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. bool Run() { - context_.error = false; - context_.error_msg.clear(); raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_); - return !context_.error; + const_cast(args_), temps_, profile_counters_); + return true; } // Returns the error message from the previous failed Run call. - const string& error_msg() const { return context_.error_msg; } + // + // TODO(fschneider): For now this always returns an empty string because there + // is no support for error reporting in XLA. Remove this once all callers are + // updated. + string error_msg() const { return {}; } // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. @@ -141,10 +145,6 @@ class XlaCompiledCpuFunction { // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in // tensorflow/compiler/aot/runtime.h to ensure correct alignment. // - // If StaticData.requires_runtime_context==true, the final argument is an - // XlaLocalRuntimeContext, which is managed internally by this class, and - // should not be changed. - // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. void set_arg_data(size_t index, void* data) { args_[index] = data; } @@ -162,6 +162,16 @@ class XlaCompiledCpuFunction { return static_cast(temps_[result_index_]); } + // Profile counters for this XLA computation. + // + // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in + // this case) these counters are non-null and are automatically populated by + // `Run`. The counters can then be pretty-printed using + // `hlo_profile_printer()`. + // + // When Hlo profiling is disabled, this accessor returns null. + const int64* profile_counters() const { return profile_counters_; } + // Returns the buffer for the positional result at the given `index`. void* result_data(size_t index) { return results()[index]; } const void* result_data(size_t index) const { return results()[index]; } @@ -195,6 +205,12 @@ class XlaCompiledCpuFunction { // program shape isn't available. const xla::ProgramShape* ProgramShape() const { return program_shape_; } + bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } + const xla::HloProfilePrinter& hlo_profile_printer() const { + assert(hlo_profiling_enabled()); + return *hlo_profile_printer_; + } + private: const RawFunction raw_function_; const size_t result_index_; @@ -208,14 +224,17 @@ class XlaCompiledCpuFunction { void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; + // Backing memory for profiling counters. + int64* profile_counters_ = nullptr; + // Options and context passed to the compiled function. xla::ExecutableRunOptions run_options_; - tensorflow::XlaLocalRuntimeContext context_; // Optional metadata. const char** arg_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShape* program_shape_ = nullptr; + const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 48cebdf74c71f974bf075e0255626ec57eb9a149..50da76e514c83912cbf864bdc3aaa7b8e4f77925 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -502,18 +502,6 @@ Status BuildComputation( return Status::OK(); } -void AssignMajorToMinorLayout(xla::Shape* shape) { - if (xla::ShapeUtil::IsTuple(*shape)) { - for (xla::Shape& elem_shape : *shape->mutable_tuple_shapes()) { - AssignMajorToMinorLayout(&elem_shape); - } - } else { - auto& minor_to_major = *shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major.Resize(xla::ShapeUtil::Rank(*shape), 0); - std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); - } -} - } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, @@ -543,8 +531,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options.resolve_compile_time_constants); core::ScopedUnref context_unref(context); - result->tuple_arg = options.use_tuple_arg; - std::vector arg_expressions; std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( @@ -564,11 +550,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->resource_updates)); - result->requires_runtime_context = context->has_context_parameter(); - - // Tuple arguments and runtime context parameters are incompatible. - TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); - VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; result->outputs.resize(context->retvals().size()); @@ -596,7 +577,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, << xla::ShapeUtil::HumanString(result->xla_output_shape); // Tensorflow expects a major-to-minor order of results. - AssignMajorToMinorLayout(&result->xla_output_shape); + xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); // Converts the output shapes to TensorShapes. int computation_output = 0; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index ac7d4cfb127d1de8c92f3a855191c45af77888ad..380e24e96bc713af4453f92a5359995e9ab4734a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -54,8 +54,6 @@ namespace tensorflow { // +---------------------+-----------------------------------------+ // Within each block, the arguments are arranged by the _Arg index from which // they were derived. -// If `Options::requires_runtime_context` is true, then an additional runtime -// context argument is passed as a final argument. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -191,16 +189,9 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Does the computation require the local runtime context to be passed as - // the last argument? - bool requires_runtime_context = false; - // Input shapes of the computation. std::vector xla_input_shapes; - // Should the arguments be packed into a single tuple? - bool tuple_arg; - // Output shape in XLA format. The output shape is always a tuple. xla::Shape xla_output_shape; @@ -232,8 +223,7 @@ class XlaCompiler { int graph_def_version = TF_GRAPH_DEF_VERSION; // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall() - // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed - // to the computation. + // for CPU. bool allow_cpu_custom_calls = false; // If not nullptr, populate_resource_manager is called with the diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 651bafd6c5d946adfedd63ebbe93e4ea016f0b37..5d19dd353fc04744e196bb50c35cb60b35d8b258 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -70,24 +70,6 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants) {} -const xla::ComputationDataHandle& -XlaContext::GetOrCreateRuntimeContextParameter() { - CHECK(allow_cpu_custom_calls_); - if (has_context_parameter_) return context_parameter_; - has_context_parameter_ = true; - - // Allocate the next available parameter for the context parameter. - int num_parameters = 0; - for (const XlaExpression& arg : args_) { - if (!arg.has_constant_value()) { - ++num_parameters; - } - } - context_parameter_ = builder_->Parameter( - num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context"); - return context_parameter_; -} - string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value @@ -178,6 +160,20 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } +const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { + return LookupOrCreate(type, &mul_func_, [this, type] { + const string type_string = DataTypeString(type); + VLOG(1) << "Building Mul() for " << type_string; + xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + b.Mul(x, y); + return b.Build().ConsumeValueOrDie(); + }); +} + const xla::Computation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, const std::function& create) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index de8aafa3628e6eebdabbc508cd95a2ac86e3472f..1a7dafe8cdb56cc9b8fcd3ba6e262c21c2a07d90 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -56,15 +56,10 @@ class XlaContext : public ResourceBase { xla::ComputationBuilder* builder(); bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool has_context_parameter() const { return has_context_parameter_; } const std::vector& args() const { return args_; } void set_args(std::vector args); - // Get the runtime context parameter, adding one if it does not already exist. - // Dies if not compiling a local executable. - const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter(); - const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value @@ -102,6 +97,11 @@ class XlaContext : public ResourceBase { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Get an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -116,16 +116,9 @@ class XlaContext : public ResourceBase { const bool allow_cpu_custom_calls_; // If true, constant return values are returned as Tensors instead of - // run-time computation outptus. + // run-time computation outputs. const bool resolve_compile_time_constants_; - // When 'has_context_parameter_' is true, this is the computation handle - // for an additional final parameter to the computation, through which will be - // passed a XlaLocalRuntimeContext* at runtime. Created on demand by - // GetOrCreateRuntimeContextParameter(). - bool has_context_parameter_ = false; - xla::ComputationDataHandle context_parameter_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; @@ -155,6 +148,9 @@ class XlaContext : public ResourceBase { // Cached computation to compute Sum of two elements, specialized by type. ComputationMap add_func_; + // Cached computation to compute Mul of two elements, specialized by type. + ComputationMap mul_func_; + // Cached computation to compute Sigmoid of an element, specialized by type. ComputationMap sigmoid_func_; diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index d504613d232c779e47a506657d2825d052e726dc..8ca757e72355d890c13b8b448d35c327d3986696 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -21,8 +21,6 @@ namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to // slow code. - // TODO(b/34969189) The implementation of TruncatedNormal generates illegal - // code on GPU. if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { return false; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9c3e15d2fa4c84af94d137f2e03107bcc980f4cd..ec9e535b707beec6ea26dc81c7ee76b1d4da9225 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines helper routines for Tla JIT compilation. +// This file defines helper routines for XLA compilation. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" @@ -121,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, DataType data_type) { switch (data_type) { + case DT_BFLOAT16: + return b->ConstantR0(bfloat16::epsilon()); case DT_FLOAT: return b->ConstantR0(std::numeric_limits::epsilon()); case DT_DOUBLE: @@ -169,6 +171,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::S16: case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = *xla::Literal::CreateR0(static_cast(value)); + break; case xla::F16: literal = *xla::Literal::CreateR0(static_cast(value)); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1dd454ea8d57e21526e5bcde0c8efc5514983b93..584417bc72c8f6645c05912e857b031cfb394e54 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -37,27 +37,14 @@ namespace { // Returns a vector of positional argument buffer sizes. xla::StatusOr> ComputeArgSizes( - const xla::ProgramShape& program_shape, bool requires_runtime_context) { + const xla::ProgramShape& program_shape) { std::vector arg_sizes; const size_t num_args = program_shape.parameters_size(); arg_sizes.reserve(num_args); for (int i = 0; i < num_args; ++i) { const xla::Shape& arg_shape = program_shape.parameters(i); - if (i == num_args - 1 && requires_runtime_context) { - // If the compiled function needs an XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. - const xla::PrimitiveType type = arg_shape.element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(program_shape)); - } - arg_sizes.push_back(-1); - } else { - constexpr size_t kPointerSize = sizeof(void*); - arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); - } + constexpr size_t kPointerSize = sizeof(void*); + arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); } return std::move(arg_sizes); } @@ -90,21 +77,6 @@ xla::StatusOr ComputeResultIndex( return result_slice.index(); } -// Adapt ComputeFunctionType, which includes a final profile_counters arg, to -// RawFunction, which doesn't include that final arg. -// -// TODO(toddw): Change RawFunction and AOT to also pass the final -// profile_counters arg, and remove this adapter. -XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( - xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { - return [compute_function](void* result, - const xla::ExecutableRunOptions* run_options, - const void** args, void** temps) { - return compute_function(result, run_options, args, temps, - /*profile_counters=*/nullptr); - }; -} - // Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold // the actual strings in nonempty_names, and hold arrays of pointers in // name_ptrs, terminated by a nullptr entry. @@ -144,9 +116,8 @@ XlaJitCompiledCpuFunction::Compile( TF_ASSIGN_OR_RETURN(xla::LocalClient * client, xla::ClientLibrary::GetOrCreateLocalClient()); xla::Computation computation; - bool requires_runtime_context; - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla( - graph_def, config, client, &computation, &requires_runtime_context)); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client, + &computation)); // Get and verify the program shape. TF_ASSIGN_OR_RETURN(std::unique_ptr program_shape, @@ -177,14 +148,13 @@ XlaJitCompiledCpuFunction::Compile( const xla::cpu::CpuExecutable* cpu_executable = static_cast(executable->executable()); XlaCompiledCpuFunction::RawFunction raw_function = - RawFunctionAdapter(cpu_executable->compute_function()); + cpu_executable->compute_function(); const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); // Compute buffer sizes and the result index, needed to run the raw function. - TF_ASSIGN_OR_RETURN( - std::vector arg_sizes, - ComputeArgSizes(*program_shape, requires_runtime_context)); + TF_ASSIGN_OR_RETURN(std::vector arg_sizes, + ComputeArgSizes(*program_shape)); TF_ASSIGN_OR_RETURN(std::vector temp_sizes, ComputeTempSizes(buffer_assignment)); TF_ASSIGN_OR_RETURN(size_t result_index, @@ -203,7 +173,6 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.temp_sizes = jit->temp_sizes_.data(); jit->static_data_.num_temps = jit->temp_sizes_.size(); jit->static_data_.result_index = result_index; - jit->static_data_.requires_runtime_context = requires_runtime_context; // Optional metadata is collected and set below. CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, @@ -211,6 +180,14 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.arg_names = jit->arg_names_.data(); jit->static_data_.result_names = jit->result_names_.data(); jit->static_data_.program_shape = jit->program_shape_.get(); + + if (cpu_executable->hlo_profiling_enabled()) { + jit->static_data_.hlo_profile_printer = + &cpu_executable->hlo_profile_printer(); + jit->static_data_.profile_counters_size = + cpu_executable->hlo_profile_printer().profile_counters_size(); + } + return std::move(jit_unique_ptr); } diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h deleted file mode 100644 index dca420d6ee3fec45f88ac3b450ab0cb4fb83d38a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ - -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -// Forward-declare the ThreadPoolDevice so that it can be ignored unless it's -// actually used. E.g. some ahead-of-time compiled computations don't need a -// thread pool. -namespace Eigen { -struct ThreadPoolDevice; -} - -namespace tensorflow { - -// An instance of this class is passed to each call from tensorflow into a -// compiled XLA computation. See xla_launch_ops.cc. -struct XlaLocalRuntimeContext { - public: - XlaLocalRuntimeContext() {} - - // Kernels implemented using custom call ops set this if they encounter an - // error. The error is checked after the entire XLA computation is - // complete. - // - // error+error_msg are used instead of Status to reduce the binary size - // overhead for ahead-of-time compiled binaries. - bool error = false; - string error_msg; - - // Kernels that need a thread pool can get it from here. - const Eigen::ThreadPoolDevice* thread_pool = nullptr; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalRuntimeContext); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index a052bb105e7d3e47f2427c98ce47e52d95af78d9..79d501b511bf37ba4a79ab9d375d6f789a36889b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -346,9 +346,9 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { } void XlaOpKernelContext::SetInvalidOutput(int index) { - const TensorShape shape; Tensor* output = nullptr; - OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); + OP_REQUIRES_OK(context_, + context_->allocate_output(index, TensorShape({}), &output)); XlaExpression* expression = CastExpressionFromUninitializedTensor(output); xla::ComputationDataHandle handle; handle.set_handle(0); @@ -417,6 +417,11 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( return XlaContext::Get(context_).GetOrCreateAdd(type); } +const xla::Computation* XlaOpKernelContext::GetOrCreateMul( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateMul(type); +} + XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 76bcf594e6a0601763844847583c18ee26d8adf3..f1ae81a5aa9d507a3e0dd577568377385b1844e6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -178,7 +178,7 @@ class XlaOpKernelContext { // If this kernel invocation is within a function execution, // call_frame() returns the call frame for the function call. - FunctionCallFrame* call_frame() const { return context_->call_frame(); } + CallFrameInterface* call_frame() const { return context_->call_frame(); } FunctionLibraryRuntime* function_library() const { return context_->function_library(); @@ -210,6 +210,11 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Gets an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + private: OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 02318cf7fa1d4edc12507f6b4d66a8e897cbe100..faf47434b5dc6b569ec4f9c91a8667de275a6315 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -187,22 +188,39 @@ void XlaOpRegistry::RegisterCompilationKernels() { // Constrain each type attribute to the intersection of: // a) the types supported by the backend, and - // b) the attribute's type constraints. - // TODO(phawkins): it may be necessary to also take the intersection with - // the set of types supported by the OpDef. + // b) the types allowed by the OpDef, and + // c) the type constraints. for (const string& type_attr : type_attrs) { KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); attr_constraint->set_name(type_attr); auto* allowed_values = attr_constraint->mutable_allowed_values()->mutable_list(); - auto it = op_registration->type_constraints.find(type_attr); + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? &constraint_it->second + : nullptr; for (DataType dtype : backend.second.supported_types) { - if (it == op_registration->type_constraints.end() || - (it != op_registration->type_constraints.end() && - it->second.find(dtype) != it->second.end())) { - allowed_values->add_type(dtype); + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; + } + // Passed all the filters, this type is allowed. + allowed_values->add_type(dtype); } if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); @@ -245,6 +263,22 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } +std::vector XlaOpRegistry::BackendNames() { + std::vector names; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& backend_pair : registry.backends_) { + names.push_back(backend_pair.first); + } + return names; +} + +bool XlaOpRegistry::IsBackendRegistered(const string& name) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + return registry.backends_.find(name) != registry.backends_.end(); +} + XlaOpRegistry& XlaOpRegistry::Instance() { static XlaOpRegistry* r = new XlaOpRegistry; return *r; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 6aee8c91cc01b4382ef867fa8e438eede008ac73..8bfd9758f7af9c6b7ed20954e72f953b629b28a6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -45,11 +45,11 @@ extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" extern const char* const DEVICE_XLA_CPU; extern const char* const DEVICE_XLA_GPU; -constexpr std::array kFloatTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE}}; -constexpr std::array kNumericTypes = { +constexpr std::array kFloatTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; +constexpr std::array kNumericTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64}}; + DT_COMPLEX64, DT_BFLOAT16}}; constexpr std::array kCpuAllTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, @@ -97,6 +97,12 @@ class XlaOpRegistry { gtl::ArraySlice supported_types, BackendOpFilter op_filter); + // Returns the names of the registered backends. + static std::vector BackendNames(); + + // Returns true iff a backend with the given name is registered. + static bool IsBackendRegistered(const string& name); + // Registers `device_name` for XLA compilation, using information from // `registration`. static void RegisterCompilationDevice(const string& device_name, @@ -116,8 +122,8 @@ class XlaOpRegistry { static void RegisterCompilationKernels(); // Returns KernelDefs for compilation ops registered on - // 'compilation_device_name'. - // Does not include kernels registered as CompilationOnly. + // 'compilation_device_name'. Does not include kernels registered as + // CompilationOnly, iff include_compilation_only_kernels=false. static std::vector DeviceKernels( const string& compilation_device_name, bool include_compilation_only_kernels); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d3f292207fee396fb4248dede5c0eeb5cd2b87c9..cd69c69889b2487ad12abea275e79fee4f5c51e6 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -20,6 +20,10 @@ package_group( load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) # Filegroup used to collect source files for dependency checking. filegroup( @@ -36,6 +40,12 @@ xla_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library_py( + name = "xla_data_proto", # bzl adds a _py suffix + srcs = ["xla_data.proto"], + visibility = ["//visibility:public"], +) + xla_proto_library( name = "xla_proto", srcs = ["xla.proto"], diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index e9449f01ad69a5722f53cce09e2884e20a0def5a..a1c5840a5f3874e27043c821ed4684da2fa6c542 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -36,6 +36,8 @@ namespace xla { template class Array3D : public Array { public: + Array3D() : Array(std::vector{0, 0, 0}) {} + // Creates an array of dimensions n1 x n2 x n3, uninitialized values. Array3D(const int64 n1, const int64 n2, const int64 n3) : Array(std::vector{n1, n2, n3}) {} diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index a716159f9e74041c4823ad20b46fa94c2d7b9d8c..c28380b689c7a0e16bf0bcbf15003f4aa15e42a7 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -67,6 +67,15 @@ class Client { std::vector arguments; ExecutionOptions execution_options; ExecutionProfile* execution_profile; + + ComputationInstance(const Computation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) + : computation(computation), + arguments(std::move(arguments)), + execution_options(execution_options), + execution_profile(execution_profile) {} }; // Executes a list ComputationInstances and returns global data produced from @@ -133,7 +142,7 @@ class Client { // Returns a vector of global data handles that point to the tuple elements. StatusOr>> DeconstructTuple( - const GlobalData& computation); + const GlobalData& data); // Retrieves the statistics of the given computation. StatusOr GetComputationStats( diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 763d94e94c2167f47b3f0777a31815f02791aa9e..317dcb4e41723b93e7e50d911f16e48bc3505a09 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -153,6 +153,7 @@ bool ComputationBuilder::MakeWindow( } else { dim->set_window_dilation(1); } + dim->set_window_reversal(false); } return true; } @@ -624,7 +625,41 @@ ComputationDataHandle ComputationBuilder::Lt( ComputationDataHandle ComputationBuilder::Dot( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{}); + StatusOr> lhs_shape_or_status = GetShape(lhs); + if (!lhs_shape_or_status.ok()) { + NoteError(lhs_shape_or_status.status()); + return ComputationDataHandle(); + } + std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape->dimensions_size() == 1 ? 0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); +} + +ComputationDataHandle ComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + DotRequest request; + *request.mutable_lhs() = lhs; + *request.mutable_rhs() = rhs; + *request.mutable_dimension_numbers() = dimension_numbers; + + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_dot_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making Dot request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); } ComputationDataHandle ComputationBuilder::Conv( @@ -693,11 +728,15 @@ bool ComputationBuilder::VerifyConvolution( } return true; }; - return check_spatial_dimensions("spatial_dimensions", - dimension_numbers.spatial_dimensions()) && + return check_spatial_dimensions( + "input_spatial_dimensions", + dimension_numbers.input_spatial_dimensions()) && check_spatial_dimensions( "kernel_spatial_dimensions", - dimension_numbers.kernel_spatial_dimensions()); + dimension_numbers.kernel_spatial_dimensions()) && + check_spatial_dimensions( + "output_spatial_dimensions", + dimension_numbers.output_spatial_dimensions()); } ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( @@ -729,11 +768,11 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( } std::vector base_area_dimensions( - dimension_numbers.spatial_dimensions_size()); + dimension_numbers.input_spatial_dimensions_size()); for (std::vector::size_type i = 0; i < base_area_dimensions.size(); ++i) { base_area_dimensions[i] = - lhs_shape->dimensions(dimension_numbers.spatial_dimensions(i)); + lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); } std::vector window_dimensions( @@ -1163,6 +1202,34 @@ ComputationDataHandle ComputationBuilder::ConvertElementType( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::BitcastConvertType( + const ComputationDataHandle& operand, PrimitiveType new_element_type) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + StatusOr> shape_status = GetShape(operand); + if (!shape_status.ok()) { + first_error_ = shape_status.status(); + return ComputationDataHandle(); + } + std::unique_ptr original = shape_status.ConsumeValueOrDie(); + + ConvertRequest request; + *request.mutable_operand() = operand; + request.set_new_element_type(new_element_type); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_bitcast_convert_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making bitcast convert request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + ComputationDataHandle ComputationBuilder::SquareF32( const ComputationDataHandle& operand) { return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), @@ -1437,6 +1504,34 @@ ComputationDataHandle ComputationBuilder::While( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::Conditional( + const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ConditionalRequest request; + *request.mutable_predicate() = predicate; + *request.mutable_true_operand() = true_operand; + *request.mutable_true_computation() = true_computation.handle(); + *request.mutable_false_operand() = false_operand; + *request.mutable_false_computation() = false_computation.handle(); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_conditional_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making conditional op request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + ComputationDataHandle ComputationBuilder::Reduce( const ComputationDataHandle& operand, const ComputationDataHandle& init_value, const Computation& computation, @@ -1816,25 +1911,27 @@ ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dimension_numbers.set_kernel_input_feature_dimension( kConvKernelInputDimension); for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_spatial_dimensions(i + 2); + dimension_numbers.add_input_spatial_dimensions(i + 2); dimension_numbers.add_kernel_spatial_dimensions(i + 2); + dimension_numbers.add_output_spatial_dimensions(i + 2); } return dimension_numbers; } /* static */ StatusOr ComputationBuilder::CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 output_batch, - int64 output_feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set( - {input_batch, input_feature, first_spatial, second_spatial}) + if (std::set({input_batch, input_feature, input_first_spatial, + input_second_spatial}) .size() != 4) { return FailedPrecondition( "dimension numbers for the input are not unique: (%lld, %lld, %lld, " "%lld)", - input_batch, input_feature, first_spatial, second_spatial); + input_batch, input_feature, input_first_spatial, input_second_spatial); } if (std::set({kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial}) @@ -1845,25 +1942,28 @@ ComputationBuilder::CreateConvDimensionNumbers( kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial); } - if (std::set( - {output_batch, output_feature, first_spatial, second_spatial}) + if (std::set({output_batch, output_feature, output_first_spatial, + output_second_spatial}) .size() != 4) { return FailedPrecondition( "dimension numbers for the output are not unique: (%lld, %lld, %lld, " "%lld)", - output_batch, output_feature, first_spatial, second_spatial); + output_batch, output_feature, output_first_spatial, + output_second_spatial); } ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(input_batch); dimension_numbers.set_input_feature_dimension(input_feature); - dimension_numbers.set_output_batch_dimension(output_batch); - dimension_numbers.set_output_feature_dimension(output_feature); - dimension_numbers.add_spatial_dimensions(first_spatial); - dimension_numbers.add_spatial_dimensions(second_spatial); + dimension_numbers.add_input_spatial_dimensions(input_first_spatial); + dimension_numbers.add_input_spatial_dimensions(input_second_spatial); dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); + dimension_numbers.add_output_spatial_dimensions(output_first_spatial); + dimension_numbers.add_output_spatial_dimensions(output_second_spatial); return dimension_numbers; } diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 4c6e320557f9202b738333fc2066ac4394fcff6b..28889ece73f5da72c3eea681c9e4aea7351d3d54 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -121,14 +121,10 @@ class ComputationBuilder { // result, OpMetadata is set on the Computation Builder. All subsequent // instructions generated via this Computation Builder will have the same // OpMetadata attached until a call to ClearOpMetdata. - void SetOpMetadata(const OpMetadata& metadata) { - metadata_ = metadata; - } + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } // Clears the HloMetadata state. - void ClearOpMetadata() { - metadata_.Clear(); - } + void ClearOpMetadata() { metadata_.Clear(); } // Sets an OpSharding that will be attached to all instructions until cleared. void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } @@ -397,6 +393,11 @@ class ComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + // Enqueues a general dot instruction onto the computation. + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + // Default dimension numbers used for a 2D convolution. static constexpr int64 kConvBatchDimension = 0; static constexpr int64 kConvFeatureDimension = 1; @@ -417,8 +418,9 @@ class ComputationBuilder { // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an // error if either the input or the weight dimension numbers have conflicts. static StatusOr CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 output_batch, - int64 output_feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial); @@ -673,6 +675,13 @@ class ComputationBuilder { ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, PrimitiveType new_element_type); + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand, + PrimitiveType new_element_type); + // Enqueues a float32 reciprocal instruction onto the computation. // (float32 is specified as there is an implicit float32 -1.0f constant // exponent). @@ -732,6 +741,13 @@ class ComputationBuilder { const Computation& body, const ComputationDataHandle& init); + // Enqueues a conditional node onto the computation. + ComputationDataHandle Conditional(const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation); + // Enqueues a ReducePrecision node onto the computation. ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, const int exponent_bits, @@ -807,7 +823,7 @@ class ComputationBuilder { // The operand must represent a constant value, which in this case // means that it must not statically depend on any parameter of the // computation that is being built other then the ones specified on the - // paramtere list. The parameters in the list will be indexed by their + // parameter list. The parameters in the list will be indexed by their // parameter id property so the number of parameters specified should be at // least as many as the largest used parameter index. // diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index d936bd870b8b4e63e5c9b067478c19dd2e42006a..5f2b55713e342aa3d0251386d57cb52481fe748d 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -51,7 +51,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { - if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) { + if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) { StatusOr> literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index c3c664f76af78507925274455dc35b2902f0ac4a..7900246a4937a15fda0502c44cd9762c789109a0 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -78,14 +78,14 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( } for (int i = 0; i < arguments.size(); ++i) { if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( - arguments[i]->shape())) { + arguments[i]->on_host_shape())) { return InvalidArgument( "argument does not match shape or layout of computation parameter " "%d: expected %s, got %s", i, ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), - ShapeUtil::HumanString(arguments[i]->shape()).c_str()); + ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } } @@ -275,22 +275,15 @@ StatusOr> LocalClient::Compile( device_ordinal, options)); } -// Copy the literal data to the device with the given ordinal and return as a -// ScopedShapedBuffer. The given memory allocator is used for device memory -// allocation. StatusOr> LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator) { if (allocator == nullptr) { allocator = backend().memory_allocator(); } - TF_ASSIGN_OR_RETURN( - auto scoped_buffer, - ScopedShapedBuffer::Allocate( - literal.shape(), allocator, device_ordinal, - [this](const Shape& shape) { - return backend().transfer_manager()->GetByteSizeRequirement(shape); - })); + TF_ASSIGN_OR_RETURN(auto scoped_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), allocator, device_ordinal)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( @@ -298,8 +291,6 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, return std::move(scoped_buffer); } -// Copy the data from the device contained in the given ShapedBuffer and -// return as a Literal. StatusOr> LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN( @@ -309,4 +300,22 @@ StatusOr> LocalClient::ShapedBufferToLiteral( shaped_buffer); } +Status LocalClient::TransferToInfeedLocal(const Literal& literal, + int device_ordinal) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device_ordinal)); + return backend().transfer_manager()->TransferLiteralToInfeed(executor, + literal); +} + +StatusOr> LocalClient::TransferFromOutfeedLocal( + const Shape& shape, int device_ordinal) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device_ordinal)); + auto literal = MakeUnique(); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( + executor, shape, literal.get())); + return std::move(literal); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 32fe0d9f84e56f44e4098571e558c7e846d003b5..3ca0d2ef5513cfb6b0dbfbc63b311f81a318356e 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -162,6 +162,20 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Transfer the given literal to the infeed queue of the given device. + // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does + // not inherit from Client and there is no possibility of confusion with + // Client::TransferToInfeed. + Status TransferToInfeedLocal(const Literal& literal, int device_ordinal); + + // Transfer and return a value of the given shape from the outfeed of the + // given device. + // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does + // not inherit from Client and there is no possibility of confusion with + // Client::TransferFromOutfeed. + StatusOr> TransferFromOutfeedLocal( + const Shape& shape, int device_ordinal); + // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 76c0168f370ff1f0749759705b7ecff359a80341..2ee23927d86612a59470dd3d3a219d00055ec65b 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -78,7 +78,7 @@ namespace xla { int64 scale = 1; int64 linear_index = 0; bool first = true; - for (auto dimension : shape.layout().minor_to_major()) { + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { if (first) { // Avoid two multiplies on the first loop iteration linear_index = multi_index[dimension]; @@ -110,7 +110,7 @@ namespace xla { // Accumulated product D{L(0)} * D{L(1)} * ... int64 divisor = 1; - for (auto dimension : shape.layout().minor_to_major()) { + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { multi_index[dimension] = (linear_index / divisor) % shape.dimensions(dimension); divisor *= shape.dimensions(dimension); @@ -133,18 +133,17 @@ namespace xla { /* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, int64 dimension) { - const Layout& layout = shape.layout(); - int64 pdim_size = layout.padded_dimensions_size(); + int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size(); int64 stride = 1; DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); - for (auto dim : layout.minor_to_major()) { + for (auto dim : LayoutUtil::MinorToMajor(shape)) { if (dim == dimension) { break; } if (pdim_size == 0) { stride *= shape.dimensions(dim); } else { - stride *= layout.padded_dimensions(dim); + stride *= LayoutUtil::PaddedDimension(shape, dim); } } return stride; diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 5c2cc2a7a99cc51ded3d98c9dd5903e4b3078548..f9803be32f5fc3c6b2f7e2527eec0a766647abc7 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -57,6 +57,7 @@ void SetDefaultLayoutToContainer( /* static */ Layout LayoutUtil::MakeLayout( tensorflow::gtl::ArraySlice minor_to_major) { Layout layout; + layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { layout.add_minor_to_major(dimension_number); } @@ -68,6 +69,7 @@ namespace { // Internal helper that creates a default layout for an array of the given rank. Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; + layout.set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = layout.mutable_minor_to_major(); minor_to_major->Resize(rank, 0); @@ -105,7 +107,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : *shape->mutable_tuple_shapes()) { SetToDefaultLayout(&element_shape); } + shape->clear_layout(); + } else if (ShapeUtil::IsOpaque(*shape)) { + shape->clear_layout(); } else { + shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); @@ -137,8 +143,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } return tensorflow::Status::OK(); - } else if (ShapeUtil::Rank(shape) == 0 && !shape.has_layout()) { - // A scalar without a layout is ok. + } else if (ShapeUtil::IsOpaque(shape)) { + if (shape.has_layout()) { + return InvalidArgument("opaque should not have a layout field"); + } return tensorflow::Status::OK(); } else { // Array shape. @@ -156,46 +164,59 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { + if (ShapeUtil::IsOpaque(shape)) { + return tensorflow::Status::OK(); + } + + if (layout.format() == INVALID_FORMAT) { return InvalidArgument( - "layout minor_to_major field contains %d elements, " - "but shape is rank %lld: {%s}; shape: %s", - layout.minor_to_major_size(), ShapeUtil::Rank(shape), - tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), - shape.ShortDebugString().c_str()); + "Layout does not have a valid format: layout {%s}, shape {%s}", + layout.ShortDebugString().c_str(), shape.ShortDebugString().c_str()); } - std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { - int64 dim = layout.minor_to_major(i); - if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + if (layout.format() == DENSE) { + if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "layout minor_to_major field has out-of-bounds value: %s", - HumanString(layout).c_str()); + "layout minor_to_major field contains %d elements, " + "but shape is rank %lld: {%s}; shape: %s", + layout.minor_to_major_size(), ShapeUtil::Rank(shape), + tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), + shape.ShortDebugString().c_str()); } - if (dimensions_in_layout[dim]) { - return InvalidArgument( - "layout minor_to_major field has duplicate values: {%s}", - HumanString(layout).c_str()); - } - dimensions_in_layout[dim] = true; - } - if (layout.padded_dimensions_size() > 0) { - if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { - return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %lld", - layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); + std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); + for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + int64 dim = layout.minor_to_major(i); + if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + return InvalidArgument( + "layout minor_to_major field has out-of-bounds value: %s", + HumanString(layout).c_str()); + } + if (dimensions_in_layout[dim]) { + return InvalidArgument( + "layout minor_to_major field has duplicate values: {%s}", + HumanString(layout).c_str()); + } + dimensions_in_layout[dim] = true; } - for (int i = 0; i < layout.padded_dimensions_size(); ++i) { - if (layout.padded_dimensions(i) < shape.dimensions(i)) { + + if (layout.padded_dimensions_size() > 0) { + if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "for dimension %d, dimension padding (%lld) is smaller than " - "the dimension size (%lld) of the shape", - i, layout.padded_dimensions(i), shape.dimensions(i)); + "layout has %d padded dimensions, but shape is rank %lld", + layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); + } + for (int i = 0; i < layout.padded_dimensions_size(); ++i) { + if (layout.padded_dimensions(i) < shape.dimensions(i)) { + return InvalidArgument( + "for dimension %d, dimension padding (%lld) is smaller than " + "the dimension size (%lld) of the shape", + i, layout.padded_dimensions(i), shape.dimensions(i)); + } } } } + return tensorflow::Status::OK(); } @@ -213,12 +234,23 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::ClearLayout(program_shape->mutable_result()); } +/* static */ bool LayoutUtil::IsDense(const Shape& shape) { + return ShapeUtil::IsArray(shape) && shape.has_layout() && + IsDense(shape.layout()); +} + +/* static */ bool LayoutUtil::IsDense(const Layout& layout) { + return layout.format() == DENSE; +} + /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) { + CHECK(layout.format() == DENSE); return std::is_sorted(layout.minor_to_major().begin(), layout.minor_to_major().end()); } /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) { + CHECK(layout.format() == DENSE); return std::is_sorted(layout.minor_to_major().begin(), layout.minor_to_major().end(), std::greater()); } @@ -228,6 +260,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } + CHECK(IsDense(shape)); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { @@ -237,15 +270,32 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return false; } +/* static */ tensorflow::gtl::ArraySlice +LayoutUtil::PaddedDimensions(const Shape& shape) { + CHECK(IsDense(shape)); + return AsInt64Slice(shape.layout().padded_dimensions()); +} + +/* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape, + int64 index) { + CHECK(IsDense(shape)); + return shape.layout().padded_dimensions(index); +} + +/* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) { + CHECK(IsDense(shape)); + return shape.layout().padding_value(); +} + /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); + } else if (ShapeUtil::IsOpaque(shape)) { + return true; } - // A scalar trivially always has a layout. - return (ShapeUtil::Rank(shape) == 0 || - (shape.has_layout() && (shape.layout().minor_to_major_size() > 0))); + return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; } /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) { @@ -261,6 +311,18 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return protobuf_util::ProtobufEquals(lhs, rhs); } +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( + const Shape& shape) { + CHECK(IsDense(shape)); + return AsInt64Slice(shape.layout().minor_to_major()); +} + +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( + const Layout& layout) { + CHECK(layout.format() == DENSE); + return AsInt64Slice(layout.minor_to_major()); +} + /* static */ int64 LayoutUtil::Major(const Layout& layout, int64 physical_dimension_number) { CHECK_LE(0, physical_dimension_number); @@ -271,6 +333,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ int64 LayoutUtil::Minor(const Layout& layout, int64 physical_dimension_number) { + CHECK_EQ(layout.format(), DENSE); CHECK_LE(0, physical_dimension_number); CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); return layout.minor_to_major(physical_dimension_number); diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index bc42e222292933be35e82d1fe50802e8830d16b3..d00cd03756360a279ad8b803476f72bac0568734 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -71,6 +71,12 @@ class LayoutUtil { // Clears the layout on all Shapes within the given ProgramShape. static void ClearLayout(ProgramShape* program_shape); + // Returns whether the given Shape is an array and has a dense format layout. + static bool IsDense(const Shape& shape); + + // Returns whether the given Layout has a dense format. + static bool IsDense(const Layout& layout); + // Returns whether the layout is monotonic and dim 0 is minor in the layout. // * R0 and R1: this is always trivially true. // * R2+: equivalent to column-major. Dimension 0 is the minor, dimension 1 is @@ -88,6 +94,19 @@ class LayoutUtil { // dimension size). static bool IsPadded(const Shape& shape); + // Returns the padded_dimensions array for the given Shape. Requires that the + // shape is an array and has a dense layout. + static tensorflow::gtl::ArraySlice PaddedDimensions( + const Shape& shape); + + // Returns the given index of the padded_dimensions array for the given Shape. + // Requires that the shape is an array and has a dense layout. + static int64 PaddedDimension(const Shape& shape, int64 index); + + // Returns the padding_value for the given Shape. Requires that the shape is + // an array and has a dense layout. + static PaddingValue GetPaddingValue(const Shape& shape); + // Returns whether the given shape has a layout. For tuple shapes, true is // returned only if all elements have layouts. static bool HasLayout(const Shape& shape); @@ -98,6 +117,11 @@ class LayoutUtil { // Returns whether lhs and rhs are identical. static bool Equal(const Layout& lhs, const Layout& rhs); + // Returns the minor_to_major array for the given Shape. Requires that the + // shape is an array and has a dense layout. + static tensorflow::gtl::ArraySlice MinorToMajor(const Shape& shape); + static tensorflow::gtl::ArraySlice MinorToMajor(const Layout& layout); + // Major(0) is the most major logical dimension number, major(1) is the // second-most-major logical dimension number and so on. // diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 93d3cd425f0a868b51677058796e9c40c2d3dff8..f493460e795deaf66fa57b9fa42918fdd05bdac6 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -64,12 +64,12 @@ Literal::StrideConfig::StrideConfig( if (!dimensions.empty()) { // Selects the shape with the largest minor dimension as the one upon // which to run the tight stride loop. - if (dimensions[source_shape.layout().minor_to_major()[0]] >= - dimensions[dest_shape.layout().minor_to_major()[0]]) { - minor_dimension = source_shape.layout().minor_to_major()[0]; + if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= + dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { + minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); } else { - minor_dimension = dest_shape.layout().minor_to_major()[0]; + minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); source_stride = IndexUtil::GetDimensionStride(source_shape, minor_dimension); } @@ -252,6 +252,10 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(1); case S64: return *Literal::CreateR0(1); + case F16: + return *Literal::CreateR0(static_cast(1.0f)); + case BF16: + return *Literal::CreateR0(static_cast(1.0f)); case F32: return *Literal::CreateR0(1); case F64: @@ -263,8 +267,6 @@ Status Literal::Copy(const Literal& src_literal, case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - return *Literal::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -402,6 +404,27 @@ std::unique_ptr Literal::Relayout( return outer_result; } +std::unique_ptr Literal::Relayout( + const Shape& shape_with_layout) const { + CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) + << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) + << " not compatible with literal shape " + << ShapeUtil::HumanString(shape()); + std::unique_ptr result = CreateFromShape(shape_with_layout); + ShapeUtil::ForEachSubshape( + result->shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + DimensionVector base(ShapeUtil::Rank(subshape), 0); + DimensionVector copy_size(subshape.dimensions().begin(), + subshape.dimensions().end()); + TF_CHECK_OK(result->GetSubliteral(index).Copy(GetSubliteral(index), + base, base, copy_size)); + } + }); + return result; +} + StatusOr> Literal::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (ShapeUtil::IsTuple(shape())) { @@ -409,10 +432,8 @@ StatusOr> Literal::Reshape( } std::unique_ptr output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - std::vector minor_to_major(ShapeUtil::Rank(shape())); - std::iota(minor_to_major.rbegin(), minor_to_major.rend(), - static_cast(0)); - output = Relayout(LayoutUtil::MakeLayout(minor_to_major)); + output = + Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { output = CloneToUnique(); } @@ -458,9 +479,10 @@ std::unique_ptr Literal::Transpose( // dimension has within the transposed array, a layout is affine if // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major // vector of the affine layout. + CHECK(LayoutUtil::IsDense(permuted_shape)); Layout* layout = permuted_shape.mutable_layout(); layout->clear_minor_to_major(); - for (auto index : shape().layout().minor_to_major()) { + for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } std::unique_ptr new_literal = CreateFromShape(permuted_shape); @@ -484,9 +506,9 @@ std::unique_ptr Literal::Slice( CHECK_GT(dimension, 0); result_dimensions.push_back(dimension); } - const auto result_shape = ShapeUtil::MakeShapeWithLayout( - shape().element_type(), result_dimensions, - AsInt64Slice(shape().layout().minor_to_major())); + const auto result_shape = + ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, + LayoutUtil::MinorToMajor(shape())); auto result_literal = MakeUnique(); *result_literal->mutable_shape() = result_shape; @@ -713,7 +735,13 @@ string Literal::ToString(bool print_layout) const { pieces.push_back("}"); } else { pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {...}"); + pieces.push_back(" {"); + EachCellAsString( + [&](tensorflow::gtl::ArraySlice indices, const string& value) { + pieces.push_back(" "); + pieces.push_back(value); + }); + pieces.push_back("}"); } return tensorflow::str_util::Join(pieces, ""); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index f37e529caf54e3aded1a418d1f01c1440cd0f284..c782e0f19e5e15eb03894b2dda40ba40b3dfaba7 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -99,6 +99,7 @@ class Literal { f16s_.clear(); f32s_.clear(); f64s_.clear(); + c64s_.clear(); tuple_literals_.clear(); } @@ -285,11 +286,15 @@ class Literal { std::unique_ptr Relayout(const Layout& new_layout, const ShapeIndex& shape_index = {}) const; - // Creates a new literal by reshaping this literal to have 'shape'. Both the - // original shape and 'shape' must contain the same number of elements. The + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. StatusOr> Reshape( - tensorflow::gtl::ArraySlice shape) const; + tensorflow::gtl::ArraySlice dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -1106,7 +1111,7 @@ void Literal::PopulateR2WithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, - AsInt64Slice(layout.minor_to_major())); + LayoutUtil::MinorToMajor(layout)); const int64 dim0_size = values.size(); const int64 dim1_size = values.begin()->size(); @@ -1137,9 +1142,10 @@ void Literal::PopulateR2( template void Literal::PopulateFromArrayWithLayout(const Array& values, const Layout& layout) { + CHECK_EQ(layout.format(), DENSE); *mutable_shape() = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), - AsInt64Slice(layout.minor_to_major())); + LayoutUtil::MinorToMajor(layout)); Reserve(values.num_elements()); values.Each([this](tensorflow::gtl::ArraySlice indices, NativeT value) { this->Set(indices, value); }); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 816bb3c549eaae4e8fc2b7d438627266603272f9..7ff64c4134155e7fe22ab99584970a7d6d6e8803 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -515,7 +515,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { TEST_F(LiteralUtilTest, ReshapeR0) { auto original = Literal::CreateR0(1.7f); - auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); + auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); EXPECT_EQ(*original, *reshape); } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 19c6a138885c61f1304bfae3d8bb5d958a1bb5bc..cb4583d198b454be1432134a9f6a77dbbbe5bdd8 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -26,6 +26,13 @@ limitations under the License. namespace xla { namespace primitive_util { +// The number of exponent bits in a BF16 value. +const int kBFloat16ExponentBits = 8; + +// The number of mantissa bits in a BF16 value. There is an implicit leading +// 1, so there is an implicit additional bit of precision. +const int kBFloat16MantissaBits = 7; + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). template diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h index 627ddf535fe734ac55d01dabb7f160b46e6e69d8..c58c19db2cacbe9b038160f27b9bd76aa58146eb 100644 --- a/tensorflow/compiler/xla/ptr_util.h +++ b/tensorflow/compiler/xla/ptr_util.h @@ -37,7 +37,7 @@ std::unique_ptr WrapUnique(T* ptr) { template typename tensorflow::helper::MakeUniqueResult::scalar MakeUnique( Args&&... args) { - return tensorflow::MakeUnique(std::forward(args)...); + return tensorflow::MakeUnique(std::forward(args)...); } // Overload for array of unknown bound. diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a6b8158671fd7872fd3492fe647558f7a3c3d1d8 --- /dev/null +++ b/tensorflow/compiler/xla/python/BUILD @@ -0,0 +1,82 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") + +py_library( + name = "xla_client", + srcs = ["xla_client.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pywrap_xla", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) + +py_test( + name = "xla_client_test", + srcs = ["xla_client_test.py"], + main = "xla_client_test.py", + srcs_version = "PY2AND3", + deps = [ + ":xla_client", + "//tensorflow/python:platform_test", + ], +) + +cc_library( + name = "numpy_bridge", + srcs = ["numpy_bridge.cc"], + hdrs = ["numpy_bridge.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/python:numpy_lib", + ], +) + +cc_library( + name = "local_computation_builder", + srcs = ["local_computation_builder.cc"], + hdrs = ["local_computation_builder.h"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:lib", + ], +) + +tf_py_wrap_cc( + name = "pywrap_xla", + srcs = ["xla.i"], + swig_includes = [ + "local_computation_builder.i", + ], + deps = [ + ":local_computation_builder", + ":numpy_bridge", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/python/__init__.py b/tensorflow/compiler/xla/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b0a53fac7adf2c088a3ceb9ae58a5ce2c7adf92 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -0,0 +1,265 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +namespace swig { + +CompiledLocalComputation::CompiledLocalComputation( + std::unique_ptr executable) + : executable_(std::move(executable)) {} + +std::unique_ptr CompiledLocalComputation::Execute( + const std::vector& arguments) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + + // Transfer arguments in + std::vector> scoped_buffers; + scoped_buffers.reserve(arguments.size()); + for (const Literal& argument : arguments) { + scoped_buffers.push_back( + client + ->LiteralToShapedBuffer(argument, + /*device_ordinal=*/0, + client->backend().memory_allocator()) + .ConsumeValueOrDie()); + } + + // Execute + std::vector argument_buffers; + argument_buffers.reserve(scoped_buffers.size()); + for (auto& buffer : scoped_buffers) { + argument_buffers.push_back(buffer.get()); + } + ExecutableRunOptions options; + options.set_allocator(client->backend().memory_allocator()); + options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + std::unique_ptr result_buffer = + executable_->Run(argument_buffers, options).ConsumeValueOrDie(); + + // Transfer result out + return client->ShapedBufferToLiteral(*result_buffer).ConsumeValueOrDie(); +} + +LocalComputation::LocalComputation(std::unique_ptr computation) + : computation_(std::move(computation)) {} + +CompiledLocalComputation* LocalComputation::Compile( + const std::vector& argument_shapes) { + std::vector argument_shape_pointers; + argument_shape_pointers.reserve(argument_shapes.size()); + for (auto& argument_shape : argument_shapes) { + argument_shape_pointers.push_back(&argument_shape); + } + + LocalClient* client = ClientLibrary::LocalClientOrDie(); + ExecutableBuildOptions options; + return new CompiledLocalComputation( + client->Compile(*computation_, argument_shape_pointers, options) + .ValueOrDie()); +} + +const Computation& LocalComputation::computation() const { + return *computation_; +} + +LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) + : builder_(ClientLibrary::LocalClientOrDie(), computation_name) {} + +LocalComputation* LocalComputationBuilder::Build() { + return new LocalComputation(std::unique_ptr( + new Computation(builder_.Build().ConsumeValueOrDie()))); +} + +ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { + return builder_.Parameter(parameter_number, shape, name); +} + +std::unique_ptr LocalComputationBuilder::GetShape( + const ComputationDataHandle& operand) { + return builder_.GetShape(operand).ConsumeValueOrDie(); +} + +ComputationDataHandle LocalComputationBuilder::ConstantLiteral( + const Literal& literal) { + return builder_.ConstantLiteral(literal); +} + +ComputationDataHandle LocalComputationBuilder::Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice broadcast_sizes) { + return builder_.Broadcast(operand, broadcast_sizes); +} + +ComputationDataHandle LocalComputationBuilder::Reshape( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return builder_.Reshape(operand, dimensions, new_sizes); +} + +ComputationDataHandle LocalComputationBuilder::Slice( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return builder_.Slice(operand, start_indices, limit_indices, strides); +} + +ComputationDataHandle LocalComputationBuilder::DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return builder_.DynamicSlice(operand, start_indices, slice_sizes); +} + +ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices) { + return builder_.DynamicUpdateSlice(operand, update, start_indices); +} + +ComputationDataHandle LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return builder_.ConcatInDim(operands, dimension); +} + +ComputationDataHandle LocalComputationBuilder::Select( + const ComputationDataHandle& pred, const ComputationDataHandle& on_true, + const ComputationDataHandle& on_false) { + return builder_.Select(pred, on_true, on_false); +} + +ComputationDataHandle LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + return builder_.Tuple(elements); +} + +ComputationDataHandle LocalComputationBuilder::GetTupleElement( + const ComputationDataHandle& tuple_data, int64 index) { + return builder_.GetTupleElement(tuple_data, index); +} + +ComputationDataHandle LocalComputationBuilder::Dot( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { + return builder_.Dot(lhs, rhs); +} + +ComputationDataHandle LocalComputationBuilder::ConvertElementType( + const ComputationDataHandle& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand, new_element_type); +} + +ComputationDataHandle LocalComputationBuilder::Call( + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands) { + return builder_.Call(local_computation.computation(), operands); +} + +ComputationDataHandle LocalComputationBuilder::Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand, permutation); +} + +ComputationDataHandle LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return builder_.Map(operands, local_computation.computation(), dimensions, + static_operands); +} + +ComputationDataHandle LocalComputationBuilder::Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return builder_.Reduce(operand, init_value, local_computation.computation(), + dimensions_to_reduce); +} + +ComputationDataHandle LocalComputationBuilder::While( + const LocalComputation& condition, const LocalComputation& body, + const ComputationDataHandle& init) { + return builder_.While(condition.computation(), body.computation(), init); +} + +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig LocalComputationBuilder::method_name args_sig { \ + return builder_.method_name args; \ + } + +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, ComputationDataHandle, \ + (const ComputationDataHandle& operand), (operand)) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs, rhs, broadcast_dimensions)) + +_FORWARD_BINOP(Eq) +_FORWARD_BINOP(Ne) +_FORWARD_BINOP(Ge) +_FORWARD_BINOP(Gt) +_FORWARD_BINOP(Lt) +_FORWARD_BINOP(Le) +_FORWARD_BINOP(Add) +_FORWARD_BINOP(Sub) +_FORWARD_BINOP(Mul) +_FORWARD_BINOP(Div) +_FORWARD_BINOP(Rem) +_FORWARD_BINOP(Max) +_FORWARD_BINOP(Min) +_FORWARD_BINOP(And) +_FORWARD_BINOP(Or) +_FORWARD_UNOP(Not) +_FORWARD_UNOP(Abs) +_FORWARD_UNOP(Exp) +_FORWARD_UNOP(Floor) +_FORWARD_UNOP(Ceil) +_FORWARD_UNOP(Log) +_FORWARD_UNOP(Sign) +_FORWARD_UNOP(Cos) +_FORWARD_UNOP(Sin) +_FORWARD_UNOP(Tanh) +_FORWARD_UNOP(SqrtF32) +_FORWARD_UNOP(SquareF32) +_FORWARD_BINOP(Pow) +_FORWARD_UNOP(IsFinite) +_FORWARD_UNOP(ReciprocalF32) +_FORWARD_UNOP(Neg) +_FORWARD_UNOP(Sort) + +#undef _FORWARD +#undef _FORWARD_UNOP +#undef _FORWARD_BINOP + +} // namespace swig + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..cbab45a5f0132eb08f291f542d40df6d0689e7ae --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -0,0 +1,210 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +namespace swig { + +// Wraps a LocalExecutable produced by compiling a +// LocalComputation. The Execute method forwards to that of the +// underlying LocalExecutable, and additionally handles tranferring +// arguments and return values in and back out of the client library's +// local client. This class is intended to be made available to Python +// via SWIG. +class CompiledLocalComputation { + public: + CompiledLocalComputation(std::unique_ptr executable); + std::unique_ptr Execute(const std::vector& arguments); + + private: + std::unique_ptr executable_; +}; + +// Wraps a Computation produced by a LocalComputationBuilder. The +// Compile method compiles the computation to a (local) executable via +// the client library's local client. This class is intended to be +// made available to Python via SWIG. +class LocalComputation { + public: + LocalComputation(std::unique_ptr computation); + CompiledLocalComputation* Compile(const std::vector& argument_shapes); + const Computation& computation() const; + + private: + std::unique_ptr computation_; +}; + +// Wraps the ComputationBuilder API in order to: +// - Support consumption by SWIG in order to be made available to +// Python. +// - Set up the underlying builder to use the client library's +// LocalClient. +// - Wrap Computations in LocalComputations for Python access. +// - Correspondingly unwrap incoming LocalComputations. +class LocalComputationBuilder { + public: + LocalComputationBuilder(const string& computation_name); + + LocalComputation* Build(); + + ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + std::unique_ptr GetShape(const ComputationDataHandle& operand); + + ComputationDataHandle ConstantLiteral(const Literal& literal); + + ComputationDataHandle Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + ComputationDataHandle Reshape(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + ComputationDataHandle Slice(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + ComputationDataHandle DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + ComputationDataHandle DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices); + + ComputationDataHandle ConcatInDim( + tensorflow::gtl::ArraySlice operands, + int64 dimension); + + ComputationDataHandle Select(const ComputationDataHandle& pred, + const ComputationDataHandle& on_true, + const ComputationDataHandle& on_false); + + ComputationDataHandle Tuple( + tensorflow::gtl::ArraySlice elements); + + ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, + int64 index); + + ComputationDataHandle Dot(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs); + + ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, + PrimitiveType new_element_type); + + ComputationDataHandle Call( + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); + + ComputationDataHandle Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice permutation); + + ComputationDataHandle Map( + tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); + + ComputationDataHandle Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + ComputationDataHandle While(const LocalComputation& condition, + const LocalComputation& body, + const ComputationDataHandle& init); + +#define _FORWARD(method_name, return_sig, args_sig) \ + return_sig method_name args_sig; + +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, ComputationDataHandle, \ + (const ComputationDataHandle& operand)) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) + + _FORWARD_BINOP(Eq) + _FORWARD_BINOP(Ne) + _FORWARD_BINOP(Ge) + _FORWARD_BINOP(Gt) + _FORWARD_BINOP(Lt) + _FORWARD_BINOP(Le) + _FORWARD_BINOP(Add) + _FORWARD_BINOP(Sub) + _FORWARD_BINOP(Mul) + _FORWARD_BINOP(Div) + _FORWARD_BINOP(Rem) + _FORWARD_BINOP(Max) + _FORWARD_BINOP(Min) + _FORWARD_BINOP(And) + _FORWARD_BINOP(Or) + _FORWARD_UNOP(Not) + _FORWARD_UNOP(Abs) + _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Floor) + _FORWARD_UNOP(Ceil) + _FORWARD_UNOP(Log) + _FORWARD_UNOP(Sign) + _FORWARD_UNOP(Cos) + _FORWARD_UNOP(Sin) + _FORWARD_UNOP(Tanh) + _FORWARD_UNOP(SqrtF32) + _FORWARD_UNOP(SquareF32) + _FORWARD_BINOP(Pow) + _FORWARD_UNOP(IsFinite) + _FORWARD_UNOP(ReciprocalF32) + _FORWARD_UNOP(Neg) + _FORWARD_UNOP(Sort) + +#undef _FORWARD +#undef _FORWARD_UNOP +#undef _FORWARD_BINOP + + private: + ComputationBuilder builder_; +}; + +static void DeleteLocalComputation(LocalComputation* computation) { + delete computation; +} + +static void DeleteCompiledLocalComputation( + CompiledLocalComputation* computation) { + delete computation; +} + +} // namespace swig + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i new file mode 100644 index 0000000000000000000000000000000000000000..ac8f3e4277739cb97c1209a22bb5c6975266e3ee --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -0,0 +1,348 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// SWIG typemaps and declarations for building, compiling, and +// executing XLA computations, wrapping most of what is declared in +// local_computation_builder.h. +// +// The typemaps below implement/assert the following correspondences +// (with elaborations below): +// +// C++ Python +// -------------------------------------+--------------------------------------- +// ComputationDataHandle <-> long +// ArraySlice <- sequence of long +// ArraySlice <- sequence of long +// Literal <-> (nested tuple of) numpy ndarray +// std::vector <- sequence of (nested tuple of) ndarray +// Shape <-> pair holding (dtype, dimensions) +// std::vector <- sequence of shape information pairs +// PrimitiveType <- int +// +// Arrows indicate whether a conversion only ever occurs in one +// direction, or whether it is maintained bidirectionally. Also, +// "long" and "int" denote the Python types so named, not C. +// +// The Python objects corresponding to C++ Literals have the type: +// +// T = ndarray | (T, ...) +// +// where a terminal numpy ndarray translates to a Literal with a +// non-tuple Shape, an XLA primitive element type corresponding to the +// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates +// to a tuple-shaped Literal whose tuple components are translated +// recursively. For example, if x is a numpy ndarray in Python, with +// shape (2, 3) and dtype of dtype('float32'), then x translates to a +// Literal with rank 2, dimension 2 and 3, and XLA primitive type +// F32. Meanwhile, +// +// (x, (x, x), (x,)), +// +// translates to a tuple-shaped XLA Literal, whose component subshapes +// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. +// +// The Python objects corresponding to C++ Shapes have the type: +// +// T = (dtype, S) +// S = DIMENSIONS | TUPLE_SHAPES +// DIMENSIONS = (int, ...) +// TUPLE_SHAPES = (T, ...) +// +// In the pair described by the T rule, the terminal dtype determines +// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is +// dtype('O'), numpy's object dtype, the structure represents a tuple +// shape and the expansion of the non-terminal S is +// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type +// and S expands into DIMENSIONS giving dimension sizes. For example: +// +// (dtype('float32'), (3, 5, 7)) +// +// describes a 3x5x7 array of F32s, and +// +// (dtype('O'), ((dtype('float32'), (2, 3)), +// (dtype('float64'), (4, 5)))) +// +// describes a tuple shape with two subshapes: the first a 2x3 F32, +// and the other a 4x5 F64. +// +// The Python int corresponding to a PrimitiveType enum must be valid +// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). +// +// The SWIG object wrappers generated by this file are not intended +// for end use, but rather for internal use in the Python XLA client, +// xla_client.py. +// +// One central reason for the Python-side indirection is that the +// Python-side objects produced by the typemaps in this file are +// further packaged up by xla_client before being passed on. For +// instance, xla_client wraps the long produced for a C++ +// ComputationDataHandle in a Python ComputationDataHandle proto, +// rather than exposing a raw long outside of the client. Similarly, +// the Python pair produced for a C++ Shape is further wrapped in a +// Python class (xla_client.Shape) so as not to expose the raw pair +// externally. +// +// Other SWIG object wrappers (e.g. of LocalComputation) are further +// wrapped by xla_client in order to set up a custom destructor that +// triggers memory deallocation on the C++ side. + +%include "tensorflow/python/platform/base.i" + +%{ +// Must be included first +#include "tensorflow/python/lib/core/numpy.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/python/local_computation_builder.h" + +using namespace xla; +using namespace xla::swig; +%} + +// Required to use PyArray_* functions. +%init %{ +tensorflow::ImportNumpy(); +%} + +// ComputationDataHandle + +%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { + const int64 handle = numpy::PyIntOrPyLongToLong($input); + if (handle == -1 && PyErr_Occurred()) { + return NULL; + } + temp.set_handle(handle); + $1 = &temp; +} + +%typemap(out) ComputationDataHandle { + $result = numpy::LongToPyIntOrPyLong($1.handle()); +} + +// ArraySlice + +%typemap(in) tensorflow::gtl::ArraySlice + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + return NULL; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + return NULL; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// ComputationDataHandle + +%typemap(in) tensorflow::gtl::ArraySlice + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + return NULL; + } + const int64 handle = numpy::PyIntOrPyLongToLong(py_int); + if (handle == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + return NULL; + } + temps[i].set_handle(handle); + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// Literal + +%typemap(in) const Literal& (std::unique_ptr temp) { + temp = numpy::XlaLiteralFromPyObject($input); + $1 = &*temp; +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyObjectFromXlaLiteral(*$1); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + temps.push_back(*numpy::XlaLiteralFromPyObject(o)); + Py_DECREF(o); + } + $1 = &temps; +} + +// Shape + +%typemap(in) const Shape& (Shape temp) { + if (!numpy::CheckPyShapeInfo($input)) { + return NULL; + } + temp = numpy::XlaShapeFromPyShapeInfo($input); + $1 = &temp; +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyShapeInfoFromXlaShape(*$1); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!numpy::CheckPyShapeInfo(o)) { + Py_DECREF(o); + return NULL; + } + temps.push_back(numpy::XlaShapeFromPyShapeInfo(o)); + Py_DECREF(o); + } + $1 = &temps; +} + +// PrimitiveType + +%typemap(in) PrimitiveType { + PyObject* py_int = numpy::PyNumberToPyInt($input); + if (!py_int) { + PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); + return NULL; + } + const long value = numpy::PyIntOrPyLongToLong(py_int); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + return NULL; + } + if (!PrimitiveType_IsValid(value)) { + PyErr_SetString( + PyExc_TypeError, "Argument not valid for PrimitiveType enum"); + Py_DECREF(py_int); + return NULL; + } + $1 = static_cast(value); +} + +%ignoreall +%unignore xla; +%unignore xla::swig; +%unignore xla::swig::CompiledLocalComputation; +%unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::LocalComputation; +%unignore xla::swig::LocalComputation::Compile; +%unignore xla::swig::LocalComputationBuilder; +%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; +%unignore xla::swig::LocalComputationBuilder::Build; +%unignore xla::swig::LocalComputationBuilder::Parameter; +%unignore xla::swig::LocalComputationBuilder::GetShape; +%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; +%unignore xla::swig::LocalComputationBuilder::ConstantR0; +%unignore xla::swig::LocalComputationBuilder::Broadcast; +%unignore xla::swig::LocalComputationBuilder::Reshape; +%unignore xla::swig::LocalComputationBuilder::Slice; +%unignore xla::swig::LocalComputationBuilder::DynamicSlice; +%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::LocalComputationBuilder::ConcatInDim; +%unignore xla::swig::LocalComputationBuilder::Select; +%unignore xla::swig::LocalComputationBuilder::Tuple; +%unignore xla::swig::LocalComputationBuilder::GetTupleElement; +%unignore xla::swig::LocalComputationBuilder::ConvertElementType; +%unignore xla::swig::LocalComputationBuilder::Call; +%unignore xla::swig::LocalComputationBuilder::Transpose; +%unignore xla::swig::LocalComputationBuilder::Map; +%unignore xla::swig::LocalComputationBuilder::Reduce; +%unignore xla::swig::LocalComputationBuilder::While; +%unignore xla::swig::LocalComputationBuilder::Eq; +%unignore xla::swig::LocalComputationBuilder::Ne; +%unignore xla::swig::LocalComputationBuilder::Ge; +%unignore xla::swig::LocalComputationBuilder::Gt; +%unignore xla::swig::LocalComputationBuilder::Lt; +%unignore xla::swig::LocalComputationBuilder::Le; +%unignore xla::swig::LocalComputationBuilder::Dot; +%unignore xla::swig::LocalComputationBuilder::Add; +%unignore xla::swig::LocalComputationBuilder::Sub; +%unignore xla::swig::LocalComputationBuilder::Mul; +%unignore xla::swig::LocalComputationBuilder::Div; +%unignore xla::swig::LocalComputationBuilder::Rem; +%unignore xla::swig::LocalComputationBuilder::Max; +%unignore xla::swig::LocalComputationBuilder::Min; +%unignore xla::swig::LocalComputationBuilder::And; +%unignore xla::swig::LocalComputationBuilder::Or; +%unignore xla::swig::LocalComputationBuilder::Not; +%unignore xla::swig::LocalComputationBuilder::Abs; +%unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Floor; +%unignore xla::swig::LocalComputationBuilder::Ceil; +%unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Sign; +%unignore xla::swig::LocalComputationBuilder::Cos; +%unignore xla::swig::LocalComputationBuilder::Sin; +%unignore xla::swig::LocalComputationBuilder::Tanh; +%unignore xla::swig::LocalComputationBuilder::SqrtF32; +%unignore xla::swig::LocalComputationBuilder::SquareF32; +%unignore xla::swig::LocalComputationBuilder::Pow; +%unignore xla::swig::LocalComputationBuilder::IsFinite; +%unignore xla::swig::LocalComputationBuilder::ReciprocalF32; +%unignore xla::swig::LocalComputationBuilder::Neg; +%unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DeleteLocalComputation; +%unignore xla::swig::DeleteCompiledLocalComputation; + +%include "tensorflow/compiler/xla/python/local_computation_builder.h" + +%unignoreall diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc new file mode 100644 index 0000000000000000000000000000000000000000..b30bdc3669de3992a08ab70ef49b0aa17cc855f3 --- /dev/null +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -0,0 +1,389 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace swig { + +namespace numpy { + +int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { + switch (primitive_type) { + case PRED: + return NPY_BOOL; + case S8: + return NPY_INT8; + case S16: + return NPY_INT16; + case S32: + return NPY_INT32; + case S64: + return NPY_INT64; + case U8: + return NPY_UINT8; + case U16: + return NPY_UINT16; + case U32: + return NPY_UINT32; + case U64: + return NPY_UINT64; + case F16: + return NPY_FLOAT16; + case F32: + return NPY_FLOAT32; + case F64: + return NPY_FLOAT64; + case TUPLE: + return NPY_OBJECT; + default: + LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type; + } +} + +PrimitiveType NumpyTypeToPrimitiveType(int np_type) { + switch (np_type) { + case NPY_BOOL: + return PRED; + case NPY_INT8: + return S8; + case NPY_INT16: + return S16; + case NPY_INT32: + return S32; + case NPY_INT64: + return S64; + case NPY_UINT8: + return U8; + case NPY_UINT16: + return U16; + case NPY_UINT32: + return U32; + case NPY_UINT64: + return U64; + case NPY_FLOAT16: + return F16; + case NPY_FLOAT32: + return F32; + case NPY_FLOAT64: + return F64; + case NPY_OBJECT: + return TUPLE; + default: + LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type; + } +} + +bool NumpyTypeIsValid(int np_type) { + switch (np_type) { + case NPY_BOOL: + case NPY_INT8: + case NPY_INT16: + case NPY_INT32: + case NPY_INT64: + case NPY_UINT8: + case NPY_UINT16: + case NPY_UINT32: + case NPY_UINT64: + case NPY_FLOAT16: + case NPY_FLOAT32: + case NPY_FLOAT64: + case NPY_OBJECT: + return true; + default: + return false; + } +} + +PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { + int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); + PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); + + PyObject* dimensions; + if (ShapeUtil::IsTuple(shape)) { + int num_elements = ShapeUtil::TupleElementCount(shape); + dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); + for (int i = 0; i < num_elements; ++i) { + PyTuple_SET_ITEM( + dimensions, i, + PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); + } + } else { + int rank = ShapeUtil::Rank(shape); + dimensions = PyTuple_New(rank); + for (int i = 0; i < rank; ++i) { + PyTuple_SET_ITEM(dimensions, i, + LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); + } + } + return PyTuple_Pack(2, np_dtype, dimensions); +} + +// Precondition: o->ob_type == &PyArrayDescr_Type +static int NumpyTypenum(PyObject* o) { + return reinterpret_cast(o)->type_num; +} + +bool CheckPyShapeInfo(PyObject* o) { + // The object is a tuple (a pair) + if (!PyTuple_Check(o)) { + PyErr_SetString(PyExc_TypeError, "Shape record must be a tuple"); + return false; + } + if (PyTuple_Size(o) != 2) { + PyErr_SetString(PyExc_ValueError, "Shape record tuple must be of length 2"); + return false; + } + + // It has a first element, which is a numpy dtype object + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + return false; + } + if (first->ob_type != &PyArrayDescr_Type) { + PyErr_SetString( + PyExc_TypeError, + "Shape record does not have a numpy dtype as its first element"); + return false; + } + const int np_type = NumpyTypenum(first); + if (!NumpyTypeIsValid(np_type)) { + PyErr_SetString(PyExc_ValueError, + "Shape record has an invalid integer dtype"); + return false; + } + + // It has a second element, which is a tuple, either of shape + // records or of Python ints + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + return false; + } + if (!PyTuple_Check(second)) { + PyErr_SetString(PyExc_TypeError, + "Shape record does not have a tuple as its second element"); + return false; + } + const int length = PyTuple_Size(second); + const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); + for (int i = 0; i < length; i++) { + PyObject* dimension = PyTuple_GetItem(second, i); + if (element_type == TUPLE) { + if (!CheckPyShapeInfo(dimension)) { + return false; + } + } else if (!CheckPyIntOrLong(dimension)) { + PyErr_SetString(PyExc_TypeError, + "Non-tuple shape record has a non-integer dimension"); + return false; + } + } + + return true; +} + +// Precondition: CheckPyShapeInfo(o) +Shape XlaShapeFromPyShapeInfo(PyObject* o) { + const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0)); + const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); + PyObject* py_dimensions = PyTuple_GetItem(o, 1); + const int length = PyTuple_Size(py_dimensions); + if (element_type == TUPLE) { + std::vector subshapes; + subshapes.reserve(length); + for (int i = 0; i < length; i++) { + subshapes.push_back( + XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i))); + } + return ShapeUtil::MakeTupleShape(subshapes); + } else { + std::vector dimensions(length); + for (int i = 0; i < length; i++) { + dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i)); + if (dimensions[i] == -1) { + CHECK(!PyErr_Occurred()); + } + } + return ShapeUtil::MakeShape(element_type, dimensions); + } +} + +PyObject* PyObjectFromXlaLiteral(const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + const std::vector& tuple_literals = literal.tuple_literals(); + int num_elements = ShapeUtil::TupleElementCount(literal.shape()); + PyObject* tuple = PyTuple_New(num_elements); + for (int i = 0; i < num_elements; i++) { + PyTuple_SET_ITEM(tuple, i, PyObjectFromXlaLiteral(tuple_literals[i])); + } + return tuple; + } else { + int rank = ShapeUtil::Rank(literal.shape()); + std::vector dimensions(rank); // NOLINT - PyArray requires a long* + for (int i = 0; i < rank; i++) { + dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); + } + int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); + PyObject* array = + PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); + CopyLiteralToNumpyArray(np_type, literal, + reinterpret_cast(array)); + return array; + } +} + +std::unique_ptr XlaLiteralFromPyObject(PyObject* o) { + if (PyTuple_Check(o)) { + int num_elements = PyTuple_Size(o); + std::vector> elements; + elements.reserve(num_elements); + for (int i = 0; i < num_elements; i++) { + PyObject* element = PyTuple_GetItem(o, i); + elements.push_back(XlaLiteralFromPyObject(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } else if (PyArray_Check(o)) { + PyArrayObject* py_array = reinterpret_cast(o); + int rank = PyArray_NDIM(py_array); + std::vector dimensions(rank); + for (int i = 0; i < rank; i++) { + dimensions[i] = PyArray_DIM(py_array, i); + } + int np_type = PyArray_TYPE(py_array); + auto literal = Literal::CreateFromDimensions( + NumpyTypeToPrimitiveType(np_type), dimensions); + CopyNumpyArrayToLiteral(np_type, py_array, literal.get()); + return literal; + } else { + LOG(FATAL) + << "Non-tuple or Numpy array encountered in conversion to XLA literal"; + } +} + +void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, + Literal* literal) { + switch (np_type) { + case NPY_BOOL: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT8: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + default: + LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; + } +} + +void CopyLiteralToNumpyArray(int np_type, const Literal& literal, + PyArrayObject* py_array) { + switch (np_type) { + case NPY_BOOL: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT8: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT16: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + default: + LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; + } +} + +PyObject* LongToPyIntOrPyLong(long x) { // NOLINT +#if PY_MAJOR_VERSION < 3 + return PyInt_FromLong(x); +#else + return PyLong_FromLong(x); +#endif +} + +long PyIntOrPyLongToLong(PyObject* o) { // NOLINT +#if PY_MAJOR_VERSION < 3 + return PyInt_AsLong(o); +#else + return PyLong_AsLong(o); +#endif +} + +bool CheckPyIntOrLong(PyObject* o) { +#if PY_MAJOR_VERSION < 3 + return PyInt_Check(o); +#else + if (!PyLong_Check(o)) { + return false; + } + int overflow = 0; + PyLong_AsLongAndOverflow(o, &overflow); + return (overflow == 0); +#endif +} + +PyObject* PyNumberToPyInt(PyObject* o) { +#if PY_MAJOR_VERSION < 3 + return PyNumber_Int(o); +#else + return PyNumber_Long(o); +#endif +} + +} // namespace numpy + +} // namespace swig + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..4e6ecbb0e8b58979ec1f1484e722725c391106fb --- /dev/null +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// These functions transform Python/Numpy data structures to XLA data +// structures and vice versa, performing copies where +// appropriate. Python tuples and Numpy ndarrays translate to XLA +// tuples and XLA literals, respectively, and Numpy shape/dtype +// information is translated to XLA shape information. + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/python/lib/core/numpy.h" + +namespace xla { + +namespace swig { + +namespace numpy { + +// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy +// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and +// vice versa. +int PrimitiveTypeToNumpyType(PrimitiveType primitive_type); +PrimitiveType NumpyTypeToPrimitiveType(int np_type); + +// Determines whether an integer-encoded Numpy dtype is valid, +// i.e. has a supported conversion to an XLA PrimitiveType. +bool NumpyTypeIsValid(int np_type); + +// Converts XLA shape information into a Python pair of the form +// (numpy dtype, dimensions). If the XLA shape represents a tuple, +// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a +// Python tuple of shape-description pairs, created +// recursively. Otherwise, `dimensions` is a Python tuple-of-integers +// providing the array dimensions. +// +// The return value is a new reference. +PyObject* PyShapeInfoFromXlaShape(const Shape& shape); + +// Returns the outcome of a best-effort check that the Python object +// is a pair of the form (numpy dtype, dimensions), as produced by +// PyShapeInfoFromXlaShape. +bool CheckPyShapeInfo(PyObject* o); + +// Performs the inverse conversion to that of PyShapeInfoFromXlaShape. +// +// The return value is a new reference. +Shape XlaShapeFromPyShapeInfo(PyObject* o); + +// Converts an XLA literal to a Python object, either a Numpy ndarray +// or a nested Python tuple thereof. +// +// To avoid transferring ownership of the data buffers that underlie +// PyArrays and XLA literals, this function makes deep copies of all +// array data. +// +// The return value is a new reference. +PyObject* PyObjectFromXlaLiteral(const Literal& literal); + +// Converts a Numpy ndarray or a nested Python tuple thereof to a +// corresponding XLA literal. +// +// To avoid transferring ownership of the data buffers that underlie +// PyArrays and XLA literals, this function makes deep copies of all +// array data. +std::unique_ptr XlaLiteralFromPyObject(PyObject* o); + +// The following functions copy array data from the buffers underlying Numpy +// ndarrays into those underlying XLA literals, and vice versa. + +void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, + Literal* literal); + +void CopyLiteralToNumpyArray(int np_type, const Literal& literal, + PyArrayObject* py_array); + +template +void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { + NativeT* source = static_cast(PyArray_DATA(py_array)); + auto dest = literal->GetMutableArraySlice(); + std::copy(source, source + PyArray_SIZE(py_array), dest.data()); +} + +template +void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { + NativeT* dest = static_cast(PyArray_DATA(py_array)); + auto source = literal.GetArraySlice(); + std::copy(source.begin(), source.end(), dest); +} + +// Workarounds for Python 2 and 3 interop + +PyObject* LongToPyIntOrPyLong(long x); // NOLINT +long PyIntOrPyLongToLong(PyObject* o); // NOLINT +bool CheckPyIntOrLong(PyObject* o); +PyObject* PyNumberToPyInt(PyObject* o); + +} // namespace numpy + +} // namespace swig + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ diff --git a/tensorflow/compiler/xla/python/xla.i b/tensorflow/compiler/xla/python/xla.i new file mode 100644 index 0000000000000000000000000000000000000000..1c4021a558d3fcff2abfdbdbad7f3928e86ed3b8 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla.i @@ -0,0 +1,18 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* XLA-wide SWIG wrapper */ + +%include "tensorflow/compiler/xla/python/local_computation_builder.i" diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py new file mode 100644 index 0000000000000000000000000000000000000000..c75d54856dd699ec5cd8a2337007a064ba709de8 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -0,0 +1,605 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An in-process, local XLA client in Python, supporting AOT compilation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python import pywrap_xla as c_api + +_UNARY_OPS = [ + 'Not', + 'Abs', + 'Exp', + 'Floor', + 'Ceil', + 'Log', + 'Sign', + 'Cos', + 'Sin', + 'Tanh', + 'SqrtF32', + 'SquareF32', + 'IsFinite', + 'ReciprocalF32', + 'Neg', + 'Sort', +] + +_BINARY_OPS = [ + 'Eq', + 'Ne', + 'Ge', + 'Gt', + 'Lt', + 'Le', + 'Add', + 'Sub', + 'Mul', + 'Div', + 'Rem', + 'Max', + 'Min', + 'And', + 'Or', + 'Pow', +] + +# Most functions are snake_case for consistency with other modules, +# whereas method names of ComputationBuilder and LocalComputation are +# CamelCase for consistency with XLA. +# pylint: disable=invalid-name + +XLA_ELEMENT_TYPE_TO_DTYPE = { + xla_data_pb2.F32: np.dtype(np.float32), + xla_data_pb2.F64: np.dtype(np.float64), + xla_data_pb2.S32: np.dtype(np.int32), + xla_data_pb2.S64: np.dtype(np.int64), + xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.TUPLE: np.dtype(np.object), +} + +DTYPE_TO_XLA_ELEMENT_TYPE = { + str(v): k + for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() +} + + +class Shape(object): + """XLA shape. + + Represents an XLA shape by a corresponding Python/Numpy type and a + list of dimensions, which are themselves Shapes in case this one + represents an XLA tuple. + """ + + def __init__(self, np_dtype, dimensions): + self.np_dtype = np_dtype + self._dimensions = dimensions + + def element_type(self): + return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] + + def is_tuple(self): + return self.element_type() == xla_data_pb2.TUPLE + + def dimensions(self): + if self.is_tuple(): + raise ValueError('Tuple shape has no dimensions') + return self._dimensions + + def tuple_shapes(self): + if not self.is_tuple(): + raise ValueError('Shape is not a tuple shape') + return self._dimensions + + @staticmethod + def from_numpy(npval): + + def convert(npval): + if isinstance(npval, tuple): + return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval)) + else: + return Shape(npval.dtype, np.shape(npval)) + + return convert(require_numpy_array_layout(npval)) + + +def _wrap_shape(shape_info): + dtype, dims = shape_info + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] + if element_type == xla_data_pb2.TUPLE: + dims = [_wrap_shape(subshape_info) for subshape_info in dims] + return Shape(dtype, dims) + + +def _unwrap_shape(shape): + if shape.is_tuple(): + components = tuple( + _unwrap_shape(subshape) for subshape in shape.tuple_shapes()) + else: + components = shape.dimensions() + return (shape.np_dtype, components) + + +def _unwrap_shapes(shapes): + return [_unwrap_shape(shape) for shape in shapes] + + +def _wrap_data_handle(handle): + cdh = xla_data_pb2.ComputationDataHandle() + cdh.handle = handle + return cdh + + +def _unwrap_data_handle(handle_proto): + return handle_proto.handle + + +def _unwrap_data_handles(handle_protos): + return [_unwrap_data_handle(cdh) for cdh in handle_protos] + + +def require_numpy_array_layout(value): + if isinstance(value, tuple): + return tuple(require_numpy_array_layout(x) for x in value) + else: + return np.require(value, requirements=['C', 'A']) + + +class LocalComputation(object): + """Python wrapper for a local XLA Computation. + + A LocalComputation can be executed if it is compiled. Otherwise, it + can still be used as a Computation where required by the + ComputationBuilder methods. + """ + + def __init__(self, c_local_computation, is_compiled): + self.c_local_computation = c_local_computation + self.is_compiled = is_compiled + + # Ensure a reference to C-based destructor for use in __del__. + if is_compiled: + self._delete = c_api.DeleteCompiledLocalComputation + else: + self._delete = c_api.DeleteLocalComputation + + def Compile(self, argument_shapes=()): + if self.is_compiled: + raise ValueError('Attempt to compile a compiled local XLA computation.') + return LocalComputation( + self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)), + is_compiled=True) + + def CompileWithExampleArguments(self, arguments=()): + return self.Compile( + argument_shapes=[Shape.from_numpy(arg) for arg in arguments]) + + def Execute(self, arguments=()): + if not self.is_compiled: + raise ValueError('Cannot execute an uncompiled local XLA computation.') + arguments = tuple(map(require_numpy_array_layout, arguments)) + return self.c_local_computation.Execute(arguments) + + def __del__(self): + self._delete(self.c_local_computation) + + +class ComputationBuilder(object): + """XLA computation builder. + + Enqueues XLA ops in sequence and in order to build a + LocalComputation, which in turn can be compiled into a + CompiledLocalComputation, which in turn can be locally executed. + """ + + # The methods of this class map 1-to-1 onto the XLA C++ + # computation builder API. Therefore, there's no need to laboriously list + # arguments and return values for every method, especially where it's obvious. + # + # pylint: disable=g-doc-return-or-yield + # pylint: disable=g-doc-args + + def __init__(self, name): + self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._parameter_numbering = itertools.count() + + def Build(self): + return LocalComputation(self._client.Build(), is_compiled=False) + + def Constant(self, value): + """Enqueues a constant op onto the computation. + + Args: + value: value for the constant, as a np.array with an explicit dtype set + to one of the supported types. + + Returns: + A ComputationDataHandle message. + """ + value = require_numpy_array_layout(value) + return _wrap_data_handle(self._client.ConstantLiteral(value)) + + def ConstantF32Scalar(self, value): + """Convenience method to enqueue a scalar F32 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.float32)) + + def ConstantF64Scalar(self, value): + """Convenience method to enqueue a scalar F32 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.float64)) + + def ConstantS32Scalar(self, value): + """Convenience method to enqueue a scalar S32 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.int32)) + + def ConstantS64Scalar(self, value): + """Convenience method to enqueue a scalar S64 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.int64)) + + def ConstantPredScalar(self, value): + """Convenience method to enqueue a scalar PRED constant op. + + Args: + value: a boolean value. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.bool)) + + def ParameterWithShape(self, shape, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation, given a shape. + + Args: + shape: the parameter's shape as a Shape object. + name: optional string name for the parameter. + parameter_num: parameter number in the computation function. If None, + the next linear parameter number is used. The default value capability + can be used for auto-numbering. If you're using auto-numbering for some + parameters, use it for *all* parameters to avoid clashes. + + Returns: + A ComputationDataHandle message. + """ + if name is None: + name = '' + if parameter_num is None: + parameter_num = next(self._parameter_numbering) + + return _wrap_data_handle( + self._client.Parameter( + parameter_num, _unwrap_shape(shape), name.encode('utf8'))) + + def ParameterFromNumpy(self, value, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation. + + Args: + value: a Numpy array, or a nested tuple thereof, from which the + shape is inferred. + name: as in ParameterWithShape. + parameter_num: as in ParameterWithShape. + + Returns: + A ComputationDataHandle message. + """ + return self.ParameterWithShape( + Shape.from_numpy(value), name=name, parameter_num=parameter_num) + + def Broadcast(self, operand, sizes): + """Enqueues a broadcast operation onto the computation. + + Args: + operand: the operand ComputationDataHandle to broadcast. + sizes: an iterable of broadcast sizes. + + Returns: + A ComputationDataHandle representing the added broadcast op. + """ + return _wrap_data_handle( + self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + + def Concatenate(self, operands, dimension): + """Enqueues a concatenate operation onto the computation. + + Args: + operands: the operands to concatenate. + dimension: the dimension in which to perform the concatenation. + + Returns: + A ComputationDataHandle representing the added concatenate op. + """ + return _wrap_data_handle( + self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + + def ConvertElementType(self, operand, new_element_type): + """Enqueues an element type conversion operation onto the computation. + + Args: + operand: the operand to convert. + new_element_type: the target primitive type. + + Returns: + A ComputationDataHandle representing the added conversion op. + """ + return _wrap_data_handle( + self._client.ConvertElementType( + _unwrap_data_handle(operand), new_element_type)) + + def GetShape(self, operand): + return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + + def GetComputationStats(self): + raise NotImplementedError() + + def Reshape(self, operand, dimensions, new_sizes): + """Reshape op.""" + return _wrap_data_handle( + self._client.Reshape( + _unwrap_data_handle(operand), dimensions, new_sizes)) + + def Trans(self, operand): + """Specialized matrix transpose op.""" + return _wrap_data_handle( + self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + + def Transpose(self, operand, permutation): + """Transpose op.""" + return _wrap_data_handle( + self._client.Transpose(_unwrap_data_handle(operand), permutation)) + + def Select(self, pred, on_true, on_false): + """Element-wise selection op. + + Constructs an output array from elements of two input arrays, based on the + values of a predicate array. + """ + return _wrap_data_handle( + self._client.Select( + _unwrap_data_handle(pred), + _unwrap_data_handle(on_true), + _unwrap_data_handle(on_false))) + + def Slice(self, operand, start_indices, limit_indices, strides=None): + """Enqueues a slice operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_indices: iterable of N integers containing the starting indices of + the slice for each dimension. + limit_indices: iterable of N integers containing the ending indices + (exclusive) of the slice for each dimension. + strides: optional iterable of N integers containing the stride sizes for + each dimension. + + Returns: + A ComputationDataHandle representing the added Slice op. + """ + if strides is None: + start_indices = list(start_indices) + strides = [1] * len(start_indices) + return _wrap_data_handle( + self._client.Slice( + _unwrap_data_handle(operand), + start_indices, + limit_indices, + strides)) + + def DynamicSlice(self, operand, start_indices, slice_sizes): + """Enqueues a slice op with dynamic start indices onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_indices: ComputationDataHandle for the 1D array of N integers + containing the starting indices of the slice. + slice_sizes: iterable of N integers containing the slice sizes in each + dimension. + + Returns: + A ComputationDataHandle representing the added DynamicSlice op. + """ + return _wrap_data_handle( + self._client.DynamicSlice( + _unwrap_data_handle(operand), + _unwrap_data_handle(start_indices), + slice_sizes)) + + def DynamicUpdateSlice(self, operand, update, start_indices): + """Enqueues a dynamic update slice operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be updated. + update: N dimensional array comprising the slice update. + start_indices: Rank-1 array of N integers comprising the starting indices + of the slice along each dimension. + Returns: + A ComputationDataHandle representing the added DynamicUpdateSlice op. + """ + return _wrap_data_handle( + self._client.DynamicUpdateSlice( + _unwrap_data_handle(operand), + _unwrap_data_handle(update), + _unwrap_data_handle(start_indices))) + + def Tuple(self, *ops): + """Enqueues a tuple operation onto the computation. + + Args: + ops: a sequence of tuple operands (each a ComputationDataHandle). + + Returns: + A ComputationDataHandle representing the added Tuple op. + """ + return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + + def GetTupleElement(self, tup, index): + """Enqueues a 'get tuple element' operation onto the computation. + + Args: + tup: the tuple operand (a ComputationDataHandle). + index: numeric index to select from the tuple. + + Returns: + A ComputationDataHandle representing the added GetTupleElement op. + """ + return _wrap_data_handle( + self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + + def Call(self, computation_to_apply, operands): + """Enqueues a call operation onto the computation. + + Args: + computation_to_apply: a Computation object. + operands: an iterable of ComputationDataHandle. The number and types of + operands must match the arity of computation_to_apply. + + Returns: + A ComputationDataHandle representing the added call op. + """ + return _wrap_data_handle( + self._client.Call(computation_to_apply.c_local_computation, + _unwrap_data_handles(operands))) + + def Map(self, operands, computation_to_apply, dimensions, static_operands=()): + """Enqueues a map operation onto the computation. + + Args: + operands: an iterable of ComputationDataHandle. + computation_to_apply: a Computation object. + dimensions: dimensions over which to apply map the function. + static_operands: auxiliary arguments passed to the applied computation. + + Returns: + A ComputationDataHandle representing the added Map op. + """ + return _wrap_data_handle( + self._client.Map( + _unwrap_data_handles(operands), + computation_to_apply.c_local_computation, + dimensions, + _unwrap_data_handles(static_operands))) + + def Reduce(self, operand, init_value, computation_to_apply, dimensions): + """Enqueues a reduction operation onto the computation. + + Args: + operand: reduction operand (ComputationDataHandle). + init_value: reduction initial value (ComputationDataHandle). + computation_to_apply: a Computation object - binary reduction function. + dimensions: sequence of dimensions (integers) to reduce on. + + Returns: + A ComputationDataHandle representing the added Reduce op. + """ + return _wrap_data_handle( + self._client.Reduce( + _unwrap_data_handle(operand), + _unwrap_data_handle(init_value), + computation_to_apply.c_local_computation, + dimensions)) + + def While(self, cond, body, init): + """Enqueues a While operation onto the computation. + + Args: + cond: a Computation for the loop condition, which has type T -> PRED + body: a Computation for the loop body, which has type T -> T + init: an ComputationDataHandle for the initial parameter, which has type T + + Returns: a ComputationDataHandle representing the While operation. + """ + return _wrap_data_handle( + self._client.While(cond.c_local_computation, + body.c_local_computation, + _unwrap_data_handle(init))) + + def Dot(self, lhs, rhs): + """Matrix multiplication between lhs and rhs.""" + return _wrap_data_handle( + self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + + +def _forward_methods_to_local_builder(): + """Forward remaining ComputationBuilder methods to the C API. + + Set up methods, corresponding to unary and binary XLA operations, + whose calls are forwarded in a boilerplate manner to the underlying + LocalComputationBuilder C-extension API. + """ + + def forward_to_local_builder_with_handles(target_method, is_binop=False): + """Generate a forwarding method that wraps/unwraps data handles.""" + + def forward(self, *args, **kwargs): + unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + + if is_binop and len(unwrapped_args) < 3: + unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + + return _wrap_data_handle( + target_method( + self._client, # pylint: disable=protected-access + *unwrapped_args)) + + return forward + + for method_name in _UNARY_OPS: + forward = forward_to_local_builder_with_handles( + getattr(c_api.LocalComputationBuilder, method_name)) + forward.__name__ = method_name + setattr(ComputationBuilder, method_name, forward) + + for method_name in _BINARY_OPS: + forward = forward_to_local_builder_with_handles( + getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + forward.__name__ = method_name + setattr(ComputationBuilder, method_name, forward) + + +_forward_methods_to_local_builder() diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py new file mode 100644 index 0000000000000000000000000000000000000000..878cd83edcc4bffee6bcfe31fe6a4e2705edf401 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -0,0 +1,898 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the Python extension-based XLA client.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.compiler.xla.python import xla_client +import unittest + + +class LocalComputationTest(unittest.TestCase): + """Base class for running an XLA Computation through the local client.""" + + def _NewComputation(self, name=None): + if name is None: + name = self.id() + return xla_client.ComputationBuilder(name) + + def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): + assert expected is not None + compiled_c = c.Build().CompileWithExampleArguments(arguments) + result = compiled_c.Execute(arguments) + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict so assert shapes as well. + self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) + assert_func(result, expected) + + def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) + + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, + expected) + + +def NumpyArrayF32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" + return np.array(*args, dtype=np.float32, **kwargs) + + +def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + + +def NumpyArrayS32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" + return np.array(*args, dtype=np.int32, **kwargs) + + +def NumpyArrayS64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int64 dtype.""" + return np.array(*args, dtype=np.int64, **kwargs) + + +def NumpyArrayBool(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" + return np.array(*args, dtype=np.bool, **kwargs) + + +class ComputationsWithConstantsTest(LocalComputationTest): + """Tests focusing on Constant ops.""" + + def testConstantScalarSumF32(self): + c = self._NewComputation() + c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testConstantScalarSumF64(self): + c = self._NewComputation() + c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testConstantScalarSumS32(self): + c = self._NewComputation() + c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) + self._ExecuteAndCompareClose(c, expected=3) + + def testConstantScalarSumS64(self): + c = self._NewComputation() + c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) + self._ExecuteAndCompareClose(c, expected=3) + + def testConstantVectorMulF32(self): + c = self._NewComputation() + c.Mul( + c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), + c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) + self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + + def testConstantVectorMulF64(self): + c = self._NewComputation() + c.Mul( + c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), + c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) + self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + + def testConstantVectorScalarDivF32(self): + c = self._NewComputation() + c.Div( + c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), + c.ConstantF32Scalar(2.0)) + self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + + def testConstantVectorScalarDivF64(self): + c = self._NewComputation() + c.Div( + c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), + c.ConstantF64Scalar(2.0)) + self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + + def testConstantVectorScalarPowF32(self): + c = self._NewComputation() + c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) + self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + + def testConstantVectorScalarPowF64(self): + c = self._NewComputation() + c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) + self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + + def testBooleanAnd(self): + c = self._NewComputation() + c.And( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) + + def testBooleanOr(self): + c = self._NewComputation() + c.Or( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + + def testSum2DF32(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + + def testSum2DF64(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) + self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + + def testSum2DWith1DBroadcastDim0F32(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([10, 20, 30])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + + def testSum2DWith1DBroadcastDim0F64(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF64([10, 20, 30])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + + def testSum2DWith1DBroadcastDim1F32(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([10, 20, 30])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + + def testSum2DWith1DBroadcastDim1F64(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF64([10, 20, 30])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + + def testConstantAxpyF32(self): + c = self._NewComputation() + c.Add( + c.Mul( + c.ConstantF32Scalar(2), + c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), + c.Constant(NumpyArrayF32([100, -100, 200, -200]))) + self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + + def testConstantAxpyF64(self): + c = self._NewComputation() + c.Add( + c.Mul( + c.ConstantF64Scalar(2), + c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), + c.Constant(NumpyArrayF64([100, -100, 200, -200]))) + self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + + +class ParametersTest(LocalComputationTest): + """Tests focusing on Parameter ops and argument-passing.""" + + def setUp(self): + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3]) + self.f64_scalar_2 = NumpyArrayF64(2.0) + self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3]) + self.s32_scalar_3 = NumpyArrayS32(3) + self.s32_4vector = NumpyArrayS32([10, 15, -2, 7]) + self.s64_scalar_3 = NumpyArrayS64(3) + self.s64_4vector = NumpyArrayS64([10, 15, -2, 7]) + + def testScalarTimesVectorAutonumberF32(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.f32_scalar_2) + p1 = c.ParameterFromNumpy(self.f32_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareClose( + c, + arguments=[self.f32_scalar_2, self.f32_4vector], + expected=[-4.6, 6.6, -8.6, 10.6]) + + def testScalarTimesVectorAutonumberF64(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.f64_scalar_2) + p1 = c.ParameterFromNumpy(self.f64_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareClose( + c, + arguments=[self.f64_scalar_2, self.f64_4vector], + expected=[-4.6, 6.6, -8.6, 10.6]) + + def testScalarTimesVectorS32(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.s32_scalar_3) + p1 = c.ParameterFromNumpy(self.s32_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, + arguments=[self.s32_scalar_3, self.s32_4vector], + expected=[30, 45, -6, 21]) + + def testScalarTimesVectorS64(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.s64_scalar_3) + p1 = c.ParameterFromNumpy(self.s64_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, + arguments=[self.s64_scalar_3, self.s64_4vector], + expected=[30, 45, -6, 21]) + + def testScalarMinusVectorExplicitNumberingF32(self): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1) + p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0) + c.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, + arguments=[self.f32_scalar_2, self.f32_4vector], + expected=[-4.3, 1.3, -6.3, 3.3]) + + def testScalarMinusVectorExplicitNumberingF64(self): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1) + p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0) + c.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, + arguments=[self.f64_scalar_2, self.f64_4vector], + expected=[-4.3, 1.3, -6.3, 3.3]) + + +class SingleOpTest(LocalComputationTest): + """Tests for single ops. + + The goal here is smoke testing - to exercise the most basic functionality of + single XLA ops. As minimal as possible number of additional ops are added + around the op being tested. + """ + + def testConcatenateF32(self): + c = self._NewComputation() + c.Concatenate( + (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))), + dimension=0) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def testConcatenateF64(self): + c = self._NewComputation() + c.Concatenate( + (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))), + dimension=0) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def testConvertElementType(self): + xla_types = { + np.bool: xla_client.xla_data_pb2.PRED, + np.int32: xla_client.xla_data_pb2.S32, + np.int64: xla_client.xla_data_pb2.S64, + np.float32: xla_client.xla_data_pb2.F32, + np.float64: xla_client.xla_data_pb2.F64, + } + + def _ConvertAndTest(template, src_dtype, dst_dtype): + c = self._NewComputation() + x = c.Constant(np.array(template, dtype=src_dtype)) + c.ConvertElementType(x, xla_types[dst_dtype]) + + result = c.Build().Compile().Execute() + expected = np.array(template, dtype=dst_dtype) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(result.dtype, expected.dtype) + np.testing.assert_equal(result, expected) + + x = [0, 1, 0, 0, 1] + for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): + _ConvertAndTest(x, src_dtype, dst_dtype) + + def testDotMatrixVectorF32(self): + c = self._NewComputation() + lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF32([[10.0], [20.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixVectorF64(self): + c = self._NewComputation() + lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF64([[10.0], [20.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixMatrixF32(self): + c = self._NewComputation() + lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixMatrixF64(self): + c = self._NewComputation() + lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testBooleanNot(self): + c = self._NewComputation() + arr = NumpyArrayBool([True, False, True]) + c.Not(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=~arr) + + def testExp(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Exp(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + + def testLog(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log(arr)) + + def testNeg(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Neg(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=-arr) + + def testFloor(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Floor(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.floor(arr)) + + def testCeil(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Ceil(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) + + def testAbs(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) + c.Abs(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.abs(arr)) + + def testTanh(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Tanh(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) + + def testTrans(self): + + def _TransposeAndTest(array): + c = self._NewComputation() + c.Trans(c.Constant(array)) + self._ExecuteAndCompareClose(c, expected=array.T) + + # Test square and non-square matrices in both default (C) and F orders. + for array_fun in [NumpyArrayF32, NumpyArrayF64]: + _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]])) + _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F")) + _TransposeAndTest(array_fun([[1, 2], [4, 5]])) + _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F")) + + def testTranspose(self): + + def _TransposeAndTest(array, permutation): + c = self._NewComputation() + c.Transpose(c.Constant(array), permutation) + expected = np.transpose(array, permutation) + self._ExecuteAndCompareClose(c, expected=expected) + + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + + arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) + for permutation in itertools.permutations(range(arr.ndim)): + _TransposeAndTest(arr, permutation) + _TransposeAndTest(np.asfortranarray(arr), permutation) + + def testEq(self): + c = self._NewComputation() + c.Eq( + c.Constant(NumpyArrayS32([1, 2, 3, 4])), + c.Constant(NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + + def testNe(self): + c = self._NewComputation() + c.Ne( + c.Constant(NumpyArrayS32([1, 2, 3, 4])), + c.Constant(NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) + + c.Ne( + c.Constant(NumpyArrayF32([-2.0, 0.0, + float("nan"), + float("nan")])), + c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) + self._ExecuteAndAssertWith( + np.testing.assert_allclose, c, (), expected=[True, False, True, True]) + + def testGt(self): + c = self._NewComputation() + c.Gt( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) + + def testGe(self): + c = self._NewComputation() + c.Ge( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) + + def testLt(self): + c = self._NewComputation() + c.Lt( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) + + def testLe(self): + c = self._NewComputation() + c.Le( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) + + def testMax(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) + + def testMaxExplicitBroadcastDim0(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) + + def testMaxExplicitBroadcastDim1(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) + + def testMin(self): + c = self._NewComputation() + c.Min( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) + + def testReshape(self): + c = self._NewComputation() + c.Reshape( + c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), + dimensions=[0, 1], + new_sizes=[2, 3]) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) + + def testSelect(self): + c = self._NewComputation() + c.Select( + c.Constant(NumpyArrayBool([True, False, False, True, False])), + c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), + c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) + self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) + + def testSlice(self): + c = self._NewComputation() + c.Slice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], + [3, 2]) + self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + + def testDynamicSlice(self): + c = self._NewComputation() + c.DynamicSlice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayS32([1, 0])), [2, 2]) + self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + + def testDynamicUpdateSlice(self): + c = self._NewComputation() + c.DynamicUpdateSlice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), + c.Constant(NumpyArrayS32([1, 1]))) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) + + def testTuple(self): + c = self._NewComputation() + c.Tuple( + c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), + c.Constant(NumpyArrayBool([True, False, False, True]))) + result = c.Build().Compile().Execute() + self.assertIsInstance(result, tuple) + np.testing.assert_equal(result[0], 42) + np.testing.assert_allclose(result[1], [1.0, 2.0]) + np.testing.assert_equal(result[2], [True, False, False, True]) + + def testGetTupleElement(self): + c = self._NewComputation() + c.GetTupleElement( + c.Tuple( + c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), + c.Constant(NumpyArrayBool([True, False, False, True]))), 1) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) + + def testBroadcast(self): + c = self._NewComputation() + c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) + self._ExecuteAndCompareExact( + c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) + + +class EmbeddedComputationsTest(LocalComputationTest): + """Tests for XLA graphs with embedded computations (such as maps).""" + + def _CreateConstantS32Computation(self): + """Computation (f32) -> s32 that returns a constant 1 for any input.""" + c = self._NewComputation("constant_s32_one") + # TODO(eliben): consider adding a nicer way to create new parameters without + # having to create dummy Numpy arrays or populating Shape messages. Perhaps + # we need our own (Python-client-own) way to represent Shapes conveniently. + c.ParameterFromNumpy(NumpyArrayF32(0)) + c.ConstantS32Scalar(1) + return c.Build() + + def _CreateConstantS64Computation(self): + """Computation (f64) -> s64 that returns a constant 1 for any input.""" + c = self._NewComputation("constant_s64_one") + # TODO(eliben): consider adding a nicer way to create new parameters without + # having to create dummy Numpy arrays or populating Shape messages. Perhaps + # we need our own (Python-client-own) way to represent Shapes conveniently. + c.ParameterFromNumpy(NumpyArrayF64(0)) + c.ConstantS64Scalar(1) + return c.Build() + + def _CreateConstantF32Computation(self): + """Computation (f32) -> f32 that returns a constant 1.0 for any input.""" + c = self._NewComputation("constant_f32_one") + c.ParameterFromNumpy(NumpyArrayF32(0)) + c.ConstantF32Scalar(1.0) + return c.Build() + + def _CreateConstantF64Computation(self): + """Computation (f64) -> f64 that returns a constant 1.0 for any input.""" + c = self._NewComputation("constant_f64_one") + c.ParameterFromNumpy(NumpyArrayF64(0)) + c.ConstantF64Scalar(1.0) + return c.Build() + + def _CreateMulF32By2Computation(self): + """Computation (f32) -> f32 that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f32_by2") + c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) + return c.Build() + + def _CreateMulF64By2Computation(self): + """Computation (f64) -> f64 that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f64_by2") + c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) + return c.Build() + + def _CreateBinaryAddF32Computation(self): + """Computation (f32, f32) -> f32 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryAddF64Computation(self): + """Computation (f64, f64) -> f64 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _CreateBinaryDivF32Computation(self): + """Computation (f32, f32) -> f32 that divides its two parameters.""" + c = self._NewComputation("div_param0_by_param1") + c.Div( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryDivF64Computation(self): + """Computation (f64, f64) -> f64 that divides its two parameters.""" + c = self._NewComputation("div_param0_by_param1") + c.Div( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _CreateTestF32Lt10Computation(self): + """Computation (f32) -> bool that tests if its parameter is less than 10.""" + c = self._NewComputation("test_f32_lt_10") + c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.)) + return c.Build() + + def _CreateTestF64Lt10Computation(self): + """Computation (f64) -> bool that tests if its parameter is less than 10.""" + c = self._NewComputation("test_f64_lt_10") + c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.)) + return c.Build() + + def _MakeSample3DArrayF32(self): + return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + + def _MakeSample3DArrayF64(self): + return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + + def testCallF32(self): + c = self._NewComputation() + c.Call( + self._CreateMulF32By2Computation(), + operands=(c.ConstantF32Scalar(5.0),)) + self._ExecuteAndCompareClose(c, expected=10.0) + + def testCallF64(self): + c = self._NewComputation() + c.Call( + self._CreateMulF64By2Computation(), + operands=(c.ConstantF64Scalar(5.0),)) + self._ExecuteAndCompareClose(c, expected=10.0) + + def testMapEachElementToS32Constant(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantS32Computation(), [0]) + self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + + def testMapEachElementToS64Constant(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantS64Computation(), [0]) + self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + + def testMapMulBy2F32(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF32By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + + def testMapMulBy2F64(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF64By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + + def testSimpleMapChainF32(self): + # Chains a map of constant-f32 with a map of mul-by-2 + c = self._NewComputation() + const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantF32Computation(), [0]) + c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + + def testSimpleMapChainF64(self): + # Chains a map of constant-f64 with a map of mul-by-2 + c = self._NewComputation() + const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantF64Computation(), [0]) + c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + + def testDivVectorsWithMapF32(self): + c = self._NewComputation() + c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), + c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), + self._CreateBinaryDivF32Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + + def testDivVectorsWithMapF64(self): + c = self._NewComputation() + c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), + c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), + self._CreateBinaryDivF64Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + + def testReduce1DtoScalarF32(self): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=10) + + def testReduce1DtoScalarF64(self): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=10) + + def testReduce2DTo1DDim0F32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + + def testReduce2DTo1DDim0F64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + + def testReduce2DTo1DDim1F32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[1]) + self._ExecuteAndCompareClose(c, expected=[6, 15]) + + def testReduce2DTo1DDim1F64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[1]) + self._ExecuteAndCompareClose(c, expected=[6, 15]) + + def testReduce3DAllPossibleWaysF32(self): + input_array = self._MakeSample3DArrayF32() + + def _ReduceAndTest(*dims): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=dims) + self._ExecuteAndCompareClose( + c, expected=np.sum(input_array, axis=tuple(dims))) + + _ReduceAndTest(0) + _ReduceAndTest(0) + _ReduceAndTest(0, 1) + _ReduceAndTest(0, 2) + _ReduceAndTest(1, 2) + _ReduceAndTest(0, 1, 2) + + def testReduce3DAllPossibleWaysF64(self): + input_array = self._MakeSample3DArrayF64() + + def _ReduceAndTest(*dims): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=dims) + self._ExecuteAndCompareClose( + c, expected=np.sum(input_array, axis=tuple(dims))) + + _ReduceAndTest(0) + _ReduceAndTest(0) + _ReduceAndTest(0, 1) + _ReduceAndTest(0, 2) + _ReduceAndTest(1, 2) + _ReduceAndTest(0, 1, 2) + + def testWhileF32(self): + cond = self._CreateTestF32Lt10Computation() + body = self._CreateMulF32By2Computation() + c = self._NewComputation() + init = c.ConstantF32Scalar(1.) + c.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=16.) + + def testWhileF64(self): + cond = self._CreateTestF64Lt10Computation() + body = self._CreateMulF64By2Computation() + c = self._NewComputation() + init = c.ConstantF64Scalar(1.) + c.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=16.) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 90aa9720a1e18bad06842adeead46fc3120d01dd..0a155400159ef178e93c378ea22467c6e257b61d 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -102,7 +102,9 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( const Array3D& lhs, const Array3D& rhs, int64 kernel_stride, Padding padding, int64 lhs_dilation, int64 rhs_dilation, const ConvolutionDimensionNumbers& dnums) { - CHECK_EQ(dnums.spatial_dimensions_size(), 1); + CHECK_EQ(dnums.input_spatial_dimensions_size(), 1); + CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1); + CHECK_EQ(dnums.output_spatial_dimensions_size(), 1); // Reuse the code for Array4D-convolution by extending the 3D input into a 4D // array by adding a fourth dummy dimension of size 1 without stride, padding // and dilation. @@ -120,8 +122,9 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( }); // Add a second dummy spatial dimensions. ConvolutionDimensionNumbers dnums2d = dnums; - dnums2d.add_spatial_dimensions(3); + dnums2d.add_input_spatial_dimensions(3); dnums2d.add_kernel_spatial_dimensions(3); + dnums2d.add_output_spatial_dimensions(3); std::unique_ptr> convr4 = ConvArray4DGeneralDimensionsDilated( a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, {rhs_dilation, 1}, dnums2d); @@ -192,14 +195,26 @@ ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { std::vector dim_lengths{static_cast(operand.size())}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow1DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { + std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0]); @@ -465,9 +480,9 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.spatial_dimensions(0)); + lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.spatial_dimensions(1)); + lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = @@ -517,7 +532,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( HloEvaluator evaluator; std::unique_ptr result_literal = - evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); + evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); auto result = @@ -703,137 +718,4 @@ ReferenceUtil::ReduceToRowArray2D( return result; } -/* static */ std::unique_ptr> ReferenceUtil::PadArray2D( - const Array2D& operand, const PaddingConfig& padding, - const float pad) { - int64 in0 = operand.n1(); - int64 high_padding0 = padding.dimensions(0).edge_padding_high(); - int64 low_padding0 = padding.dimensions(0).edge_padding_low(); - int64 interior_padding0 = padding.dimensions(0).interior_padding(); - int64 out0 = - in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; - - int64 in1 = operand.n2(); - int64 high_padding1 = padding.dimensions(1).edge_padding_high(); - int64 low_padding1 = padding.dimensions(1).edge_padding_low(); - int64 interior_padding1 = padding.dimensions(1).interior_padding(); - int64 out1 = - in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; - - auto result = MakeUnique>(out0, out1); - result->Fill(pad); - int64 o0 = low_padding0; - for (int64 i0 = 0; i0 < in0; ++i0) { - int64 o1 = low_padding1; - for (int64 i1 = 0; i1 < in1; ++i1) { - if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { - (*result)(o0, o1) = operand(i0, i1); - } - o1 += interior_padding1 + 1; - } - o0 += interior_padding0 + 1; - } - return result; -} - -/* static */ Array3D ReferenceUtil::PadArray3D( - const Array3D& operand, const PaddingConfig& padding, - const float pad) { - CHECK_EQ(padding.dimensions_size(), 3); - - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector pad_low(3); - std::vector pad_high(3); - std::vector pad_interior(3); - std::vector output_bounds(3); - for (int64 i = 0; i < 3; ++i) { - pad_low[i] = padding.dimensions(i).edge_padding_low(); - pad_high[i] = padding.dimensions(i).edge_padding_high(); - CHECK_LE(0, pad_low[i]); - CHECK_LE(0, pad_high[i]); - CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; - pad_interior[i] = padding.dimensions(i).interior_padding(); - - output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + - (input_bounds[i] - 1) * pad_interior[i]; - } - - Array3D result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector indices = {0, 0, 0}; - for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { - for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { - for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { - float* value = &result(indices[0], indices[1], indices[2]); - bool value_padded = false; - for (int i = 0; i < 3; ++i) { - bool in_low_padding = indices[i] < pad_low[i]; - bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; - if (in_low_padding || in_high_padding) { - *value = pad; - value_padded = true; - } - if (pad_interior[i] && - (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { - *value = pad; - value_padded = true; - } - } - if (value_padded) { - continue; - } - *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), - (indices[1] - pad_low[1]) / (pad_interior[1] + 1), - (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); - } - } - } - return result; -} - -/* static */ Array4D ReferenceUtil::PadArray4D( - const Array4D& operand, const PaddingConfig& padding, - const float pad) { - CHECK_EQ(padding.dimensions_size(), 4); - - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector pad_low(4); - std::vector pad_high(4); - std::vector pad_interior(4); - std::vector output_bounds(4); - for (int64 i = 0; i < 4; ++i) { - pad_low[i] = padding.dimensions(i).edge_padding_low(); - pad_high[i] = padding.dimensions(i).edge_padding_high(); - CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; - pad_interior[i] = padding.dimensions(i).interior_padding(); - - output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + - (input_bounds[i] - 1) * pad_interior[i]; - } - - Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], - output_bounds[3]); - result.Each([&](tensorflow::gtl::ArraySlice indices, float* value) { - for (int i = 0; i < 4; ++i) { - bool in_low_padding = indices[i] < pad_low[i]; - bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; - if (in_low_padding || in_high_padding) { - *value = pad; - return; - } - if (pad_interior[i] && - (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { - *value = pad; - return; - } - } - *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), - (indices[1] - pad_low[1]) / (pad_interior[1] + 1), - (indices[2] - pad_low[2]) / (pad_interior[2] + 1), - (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); - }); - return result; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 2da17307817858eea60e868f4be1ab8138784385..58e1a844610678f64677838e93f0379b63f65d39 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -70,7 +70,7 @@ class ReferenceUtil { // dilation factors. static std::unique_ptr> ConvArray4DGeneralDimensionsDilated( const Array4D& lhs, const Array4D& rhs, - std::pair stride, Padding padding, + std::pair kernel_stride, Padding padding, std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums); @@ -184,6 +184,12 @@ class ReferenceUtil { const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, @@ -486,19 +492,147 @@ class ReferenceUtil { } // Returns the result of a 2D pad on an input matrix. - static std::unique_ptr> PadArray2D( - const Array2D& operand, const PaddingConfig& padding, - const float pad); + template + static std::unique_ptr> PadArray2D( + const Array2D& operand, const PaddingConfig& padding, + const NativeT pad) { + int64 in0 = operand.n1(); + int64 high_padding0 = padding.dimensions(0).edge_padding_high(); + int64 low_padding0 = padding.dimensions(0).edge_padding_low(); + int64 interior_padding0 = padding.dimensions(0).interior_padding(); + int64 out0 = + in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; + + int64 in1 = operand.n2(); + int64 high_padding1 = padding.dimensions(1).edge_padding_high(); + int64 low_padding1 = padding.dimensions(1).edge_padding_low(); + int64 interior_padding1 = padding.dimensions(1).interior_padding(); + int64 out1 = + in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; + + auto result = MakeUnique>(out0, out1); + result->Fill(pad); + int64 o0 = low_padding0; + for (int64 i0 = 0; i0 < in0; ++i0) { + int64 o1 = low_padding1; + for (int64 i1 = 0; i1 < in1; ++i1) { + if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { + (*result)(o0, o1) = operand(i0, i1); + } + o1 += interior_padding1 + 1; + } + o0 += interior_padding0 + 1; + } + return result; + } // Returns the result of a 3D pad on an input matrix. - static Array3D PadArray3D(const Array3D& operand, - const PaddingConfig& padding, - const float pad); + template + static Array3D PadArray3D(const Array3D& operand, + const PaddingConfig& padding, + const NativeT pad) { + CHECK_EQ(padding.dimensions_size(), 3); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3()}; + std::vector pad_low(3); + std::vector pad_high(3); + std::vector pad_interior(3); + std::vector output_bounds(3); + for (int64 i = 0; i < 3; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, pad_low[i]); + CHECK_LE(0, pad_high[i]); + CHECK_LE(0, padding.dimensions(i).interior_padding()) + << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array3D result(output_bounds[0], output_bounds[1], + output_bounds[2]); + std::vector indices = {0, 0, 0}; + for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { + for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { + for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { + NativeT* value = &result(indices[0], indices[1], indices[2]); + bool value_padded = false; + for (int i = 0; i < 3; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + value_padded = true; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + value_padded = true; + } + } + if (value_padded) { + continue; + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); + } + } + } + return result; + } // Returns the result of a 4D pad on an input array. - static Array4D PadArray4D(const Array4D& operand, - const PaddingConfig& padding, - const float pad); + template + static Array4D PadArray4D(const Array4D& operand, + const PaddingConfig& padding, + const NativeT pad) { + CHECK_EQ(padding.dimensions_size(), 4); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3(), operand.n4()}; + std::vector pad_low(4); + std::vector pad_high(4); + std::vector pad_interior(4); + std::vector output_bounds(4); + for (int64 i = 0; i < 4; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, padding.dimensions(i).interior_padding()) + << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array4D result(output_bounds[0], output_bounds[1], + output_bounds[2], output_bounds[3]); + result.Each( + [&](tensorflow::gtl::ArraySlice indices, NativeT* value) { + for (int i = 0; i < 4; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + return; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + return; + } + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1), + (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); + }); + return result; + } // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, .... diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index eb6a71242ffa1499876b90f14f8a60ffdbdd069c..846ccdc83df900e3afedb6ababe07ebb1bd68f41 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -60,7 +60,9 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) { TEST_F(ReferenceUtilTest, MatmulArray2D) { Array2D rhs({ - {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, + {7.f, 8.f}, + {9.f, 10.f}, + {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); auto actual_literal = Literal::CreateR2FromArray2D(*result); @@ -326,8 +328,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { dimension_numbers.set_input_feature_dimension(0); dimension_numbers.set_output_batch_dimension(2); dimension_numbers.set_output_feature_dimension(0); - dimension_numbers.add_spatial_dimensions(1); - dimension_numbers.add_spatial_dimensions(3); + dimension_numbers.add_input_spatial_dimensions(1); + dimension_numbers.add_output_spatial_dimensions(1); + dimension_numbers.add_input_spatial_dimensions(3); + dimension_numbers.add_output_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); dimension_numbers.set_kernel_input_feature_dimension(2); dimension_numbers.add_kernel_spatial_dimensions(1); @@ -380,8 +384,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { dimension_numbers.set_input_feature_dimension(0); dimension_numbers.set_output_batch_dimension(2); dimension_numbers.set_output_feature_dimension(0); - dimension_numbers.add_spatial_dimensions(1); - dimension_numbers.add_spatial_dimensions(3); + dimension_numbers.add_input_spatial_dimensions(1); + dimension_numbers.add_output_spatial_dimensions(1); + dimension_numbers.add_input_spatial_dimensions(3); + dimension_numbers.add_output_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); dimension_numbers.set_kernel_input_feature_dimension(2); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index c4e5a7eaf34b4002c072cccf6d8e156f0a311a43..bbf6c128fb3b31657e97b608c56c27bc12045ac1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -108,6 +108,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -115,6 +116,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep @@ -1009,9 +1011,9 @@ tf_cc_test( ) cc_library( - name = "batchnorm_rewriter", - srcs = ["batchnorm_rewriter.cc"], - hdrs = ["batchnorm_rewriter.h"], + name = "batchnorm_expander", + srcs = ["batchnorm_expander.cc"], + hdrs = ["batchnorm_expander.h"], deps = [ ":hlo", ":hlo_pass", @@ -1029,11 +1031,11 @@ cc_library( ) tf_cc_test( - name = "batchnorm_rewriter_test", + name = "batchnorm_expander_test", size = "small", - srcs = ["batchnorm_rewriter_test.cc"], + srcs = ["batchnorm_expander_test.cc"], deps = [ - ":batchnorm_rewriter", + ":batchnorm_expander", ":hlo", ":hlo_matchers", ":hlo_pass", @@ -1143,6 +1145,22 @@ tf_cc_test( ], ) +cc_library( + name = "dot_decomposer", + srcs = ["dot_decomposer.cc"], + hdrs = ["dot_decomposer.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -1304,6 +1322,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) @@ -1638,10 +1657,14 @@ cc_library( deps = [ ":buffer_liveness", ":hlo", + ":hlo_alias_analysis", + ":hlo_dce", + ":hlo_graph_dumper", + ":hlo_ordering", ":hlo_pass", ":liveness_util", ":logical_buffer", - ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1656,15 +1679,17 @@ tf_cc_test( deps = [ ":copy_insertion", ":hlo", + ":hlo_graph_dumper", ":hlo_matchers", - ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1696,6 +1721,22 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_verifier_test", + srcs = ["hlo_verifier_test.cc"], + deps = [ + ":hlo", + ":hlo_verifier", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_rematerialization", srcs = ["hlo_rematerialization.cc"], @@ -1882,6 +1923,22 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_element_type_converter", + srcs = ["hlo_element_type_converter.cc"], + hdrs = ["hlo_element_type_converter.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + cc_library( name = "device_memory_allocator", srcs = ["device_memory_allocator.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index bc9a3ac43db08d1dcca72d4df8235fbe6d7f19cc..7dc09a8cbd295753570b6e554d4211335617509e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -180,19 +180,46 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { static bool Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification); + bool enable_dot_strength_reduction, bool enable_conv_simplification); private: explicit AlgebraicSimplifierVisitor( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification) + bool enable_dot_strength_reduction, bool enable_conv_simplification) : computation_(computation), is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification), + enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} + // Transforms Dots where at least one input is a vector or has a degenerate + // dimension and converts it into a multiply and reduce. This should enable + // more fusion than leaving the nodes as Dot operations. + StatusOr HandleDotStrengthReduction(HloInstruction* dot); + + // Reshapes an instruction to rank 1 if it is not already rank 1. + HloInstruction* Flatten(HloInstruction* hlo) { + if (ShapeUtil::Rank(hlo->shape()) == 1) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(hlo->shape().element_type(), + {ShapeUtil::ElementsIn(hlo->shape())}), + hlo)); + } + + // Helper method to perform and add reduction in a single dimension. + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + HloInstruction* zero = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* AddReduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + return computation_->AddInstruction(HloInstruction::CreateReduce( + shape, hlo, zero, {dim}, AddReduce_computation)); + } + // Convenience method for replacing an instruction with a bitcast. void ReplaceWithBitcast(HloInstruction* instruction); @@ -252,6 +279,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + StatusOr OptimizeDotOfConcat(HloInstruction* dot); + StatusOr OptimizeDotOfConcatHelper( + const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -265,8 +297,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Callback used to determine if a bitcast is possible. AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; - // Disable dot simplication on platforms where it causes a slowdown. - bool enable_dot_simplification_; + // Disable dot strength reduction on platforms where it causes a slowdown. + bool enable_dot_strength_reduction_; // Disable convolution simplication on platforms where it causes a slowdown. bool enable_conv_simplification_; @@ -275,10 +307,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification) { + bool enable_dot_strength_reduction, bool enable_conv_simplification) { AlgebraicSimplifierVisitor visitor( computation, is_layout_sensitive, std::move(valid_bitcast_callback), - enable_dot_simplification, enable_conv_simplification); + enable_dot_strength_reduction, enable_conv_simplification); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -574,68 +606,72 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); - if (!enable_dot_simplification_) { - return Status::OK(); - } - // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or - // below. - if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || - ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } - - // Replace a zero element dot with a broadcast of the constant 0. - if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); - } - - // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). - if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { - auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, - rhs->mutable_operand(0), lhs->mutable_operand(0))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); - } +StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( + HloInstruction* dot) { + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int64 lhs_collapsing_dim = + dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + if (lhs->IsRank2Transpose()) { + lhs = lhs->mutable_operand(0); + lhs_collapsing_dim = 1 - lhs_collapsing_dim; + } + const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; + + int64 rhs_collapsing_dim = + dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + if (rhs->IsRank2Transpose()) { + rhs = rhs->mutable_operand(0); + rhs_collapsing_dim = 1 - rhs_collapsing_dim; + } + const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; + + auto reshape_if_necessary = [&](HloInstruction* hlo) { + if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + return hlo; + } + return computation_->AddInstruction( + HloInstruction::CreateReshape(dot->shape(), hlo)); + }; - // Simplify outer product into multiply with implicit broadcasting. - // - // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) { - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, - lhs, rhs)); - } + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, + int64 dim) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, {dim})); + }; - // The following graph transformations take Dots where at least one input is a - // vector or has a degenerate dimension and converts it into a multiply and - // reduce. This should enable more fusion than leaving the nodes as Dot - // operations. + auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { + return computation_->AddInstruction(HloInstruction::CreateBinary( + local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs)); + }; // Strength reduce dot(a[K] , b[K]) = // reshape(result.shape, // reduce_sum(multiply(a, b), {0})) if (ShapeUtil::Rank(rhs->shape()) == 1 && ShapeUtil::Rank(lhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, lhs, rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(AddReduce( + multiply(Flatten(lhs), Flatten(rhs)), 0)))); + return true; + } + + if (ShapeUtil::IsEffectiveScalar(rhs->shape()) && + ShapeUtil::IsEffectiveScalar(lhs->shape())) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs))))); + return true; + } + + // Simplify outer product into multiply with implicit broadcasting. + // + // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) + if (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), + broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); + return true; } // Strength reduce dot(a[1, K], b) = @@ -646,35 +682,21 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // ) // ) if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) { - auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(lhs->shape().element_type(), - {ShapeUtil::ElementsIn(lhs->shape())}), - lhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* reduce; + (ShapeUtil::Rank(lhs->shape()) == 2 && + lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - } else { - new_lhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {rhs->shape().dimensions(1)}), - multiply, zero, {0}, add_reduce_computation)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, + reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); + return true; } - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), + rhs_collapsing_dim), + rhs), + rhs_collapsing_dim)))); + return true; } // Strength reduce dot(a, b[K, 1]) = @@ -682,26 +704,208 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) { - auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(rhs->shape().element_type(), - {ShapeUtil::ElementsIn(rhs->shape())}), - rhs)); - new_rhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_kept_dim) == 1)) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(AddReduce( + multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), + lhs_collapsing_dim)), + lhs_collapsing_dim)))); + return true; + } + return false; +} + +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0) { + return nullptr; + } + + const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + + TF_ASSIGN_OR_RETURN( + HloInstruction * optimized_lhs_concat, + OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, + rhs_contracting_dim, /*swapped=*/false)); + if (optimized_lhs_concat) { + return optimized_lhs_concat; + } + + return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, + lhs_contracting_dim, /*swapped=*/true); +} + +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( + const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { + bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && + lhs->concatenate_dimension() == lhs_contracting_dim && + rhs->opcode() == HloOpcode::kConstant; + if (!can_optimize) { + return nullptr; + } + + // We're replacing this: + // + // +-----+-----+-----+ +-------------------+ + // | | | | | | + // | | | | | R_0 | + // | | | | | | + // | | | | +-------------------+ + // | | | | | | + // | L_0 | L_1 | L_2 | * | R_1 | + // | | | | | | + // | | | | +-------------------+ + // | | | | | | + // | | | | | R_2 | + // | | | | | | + // +-----+-----+-----+ +-------------------+ + // + // with this: + // + // [Sum over i] + // + // +-----+ +-------------------+ + // | | | | + // | | * | R_i | + // | | | | + // | | +-------------------+ + // | | + // | L_i | + // | | + // | | + // | | + // | | + // | | + // +-----+ + // + // where the LHS is a concatenate operation (so we can "split" the LHS tensor + // for free) and the RHS is a constant tensor (and thus can be split at + // compile time). In the future, we may also want to do this when both the + // LHS and the RHS are concatenate operations that line up along the dimension + // being contracted over. + // + // We should be able to generalize this transform to work on a non-constant + // RHS when/if we have in-place slices or support input-fusing slices into + // Dots. + + // Dimension numbers for the new dot instructions we'll create (L_i * R_i in + // the diagram above). + DotDimensionNumbers new_dot_dnums; + new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim + : lhs_contracting_dim); + new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim + : rhs_contracting_dim); + + // Here we use the MKN notation, where the contracted dimension has K + // elements and the two non-contracted dimensions have M and N elements. + HloInstruction* add_result = nullptr; + int64 rhs_contracting_dim_offset = 0; + int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim); + for (HloInstruction* concat_op : lhs->operands()) { + int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); + Shape rhs_slice_shape(rhs->shape()); + rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); + + std::array start_indices; + start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; + start_indices[1 - rhs_contracting_dim] = 0; + + std::array limit_indices; + limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k; + limit_indices[1 - rhs_contracting_dim] = n; + + HloInstruction* rhs_slice = + computation_->AddInstruction(HloInstruction::CreateSlice( + rhs_slice_shape, rhs, /*start_indices=*/start_indices, + /*limit_indices=*/limit_indices, /*strides=*/{1, 1})); + + // TODO(b/69062148): We can get rid of `swapped` once all backends support + // "non-canonical" contraction dimensions (that contracts dimension 1 of the + // LHS with dimension 0 of the RHS). But for now we keep the same + // contraction dimensions as the incoming dot operation to ensure the new + // dot operations can be lowered. + HloInstruction *new_dot_lhs, *new_dot_rhs; + if (swapped) { + new_dot_lhs = rhs_slice; + new_dot_rhs = concat_op; + } else { + new_dot_lhs = concat_op; + new_dot_rhs = rhs_slice; + } + + auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); + + if (add_result) { + add_result = computation_->AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, add_result, new_dot)); + } else { + add_result = new_dot; + } + + rhs_contracting_dim_offset += sub_k; + } + + return add_result; +} + +Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { + auto lhs = dot->mutable_operand(0); + auto rhs = dot->mutable_operand(1); + + // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or + // below. + if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || + ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + return Status::OK(); + } + + // Replace a zero element dot with a broadcast of the constant 0. + if (ShapeUtil::HasZeroElements(dot->shape()) || + ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {lhs->shape().dimensions(0)}), - multiply, zero, {1}, add_reduce_computation)); return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } + + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, + OptimizeDotOfConcat(dot)); + if (dot_of_concat_optimized) { + VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., " + "constant)...)"; + return ReplaceInstruction(dot, dot_of_concat_optimized); + } + + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + TF_ASSIGN_OR_RETURN(bool did_strength_reduction, + HandleDotStrengthReduction(dot)); + if (did_strength_reduction) { + return Status::OK(); + } + } + + // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). + if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), + rhs->mutable_operand(0), lhs->mutable_operand(0), + dot_dimension_numbers)); + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); + } + return Status::OK(); } @@ -1108,10 +1312,37 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( Literal::One(rhs->shape().element_type()).CloneToUnique())); + + // Explicitly broadcast scalar 1 to the output shape, to avoid implicit + // broadcast in divide HLO as we are trying to eliminate implicit + // broadcasting at HLO level. + auto* broadcast_one = computation_->AddInstruction( + HloInstruction::CreateBroadcast(power->shape(), one, {})); return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, - one, lhs)); + broadcast_one, lhs)); } + + VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: " + << power->ToString(); + + // Don't perform this optimization if either of the exponents is complex; this + // identity is true only for real-valued exponents. In addition, we cowardly + // refuse to do this transformation if the two expontents have different + // element types. + if (lhs->opcode() == HloOpcode::kPower && + !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) && + !ShapeUtil::ElementIsComplex(rhs->shape()) && + ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) { + auto exponent_product = + computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + return ReplaceWithNewInstruction( + power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower, + lhs->mutable_operand(0), + exponent_product)); + } + return Status::OK(); } @@ -1165,7 +1396,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), - AsInt64Slice(operand->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(operand->shape())), new_user_operands)); VLOG(4) << " new user: " << new_user->ToString(); HloInstruction* new_reshape_or_broadcast = nullptr; @@ -1175,8 +1406,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - AsInt64Slice( - reshape_or_broadcast->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), new_user)); } else { TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); @@ -1185,8 +1415,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - AsInt64Slice( - reshape_or_broadcast->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), new_user, reshape_or_broadcast->dimensions())); } VLOG(4) << " new reshape/broadcast: " @@ -1398,6 +1627,15 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( auto operand = reduce_window->mutable_operand(0); const Window& window = reduce_window->window(); auto function = reduce_window->to_apply(); + if (ShapeUtil::IsScalar(operand->shape())) { + TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape())); + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateMap(reduce_window->shape(), + {operand, reduce_window->mutable_operand(1)}, + function)); + } + VLOG(10) << "Considering folding Pad: " << operand->ToString() << "\ninto reduce-window: " << reduce_window->ToString(); @@ -1539,15 +1777,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // still convert Conv into more efficient Matmul with operand transposition // (such as the transposition flags in cuBLAS SGEMM). if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || - input_shape.layout().minor_to_major(0) != + LayoutUtil::Minor(input_shape.layout(), 0) != dnums.input_feature_dimension() || - convolution_shape.layout().minor_to_major(0) != + LayoutUtil::Minor(convolution_shape.layout(), 0) != dnums.output_feature_dimension() || // The input feature dimension should come later in the minor-to-major // order. - (PositionInContainer(filter_shape.layout().minor_to_major(), + (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_input_feature_dimension()) < - PositionInContainer(filter_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { return Status::OK(); } @@ -1599,8 +1837,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); - auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } @@ -1688,7 +1929,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { for (auto* comp : module->MakeNonfusionComputations()) { if (AlgebraicSimplifierVisitor::Run( comp, is_layout_sensitive_, valid_bitcast_callback_, - enable_dot_simplification_, enable_conv_simplification_)) { + enable_dot_strength_reduction_, enable_conv_simplification_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index a9f476178c7af74c275a10de7727ea64e17d590f..43315f5cdc7afbe79039420320f4a0d0535e11f1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -40,11 +40,11 @@ class AlgebraicSimplifier : public HloPassInterface { // bitcasts. AlgebraicSimplifier(bool is_layout_sensitive, ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification = true, + bool enable_dot_strength_reduction = true, bool enable_conv_simplification = true) : is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification), + enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} ~AlgebraicSimplifier() override = default; tensorflow::StringPiece name() const override { return "algsimp"; } @@ -58,7 +58,7 @@ class AlgebraicSimplifier : public HloPassInterface { ValidBitcastCallback valid_bitcast_callback_; // Enable dot simplication on platforms where it is profitable. - bool enable_dot_simplification_; + bool enable_dot_strength_reduction_; // Enable convolution simplication on platforms where it is profitable. bool enable_conv_simplification_; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 620f0a54fa03e7239809e9f910893d887f9ff149..d4739ca113a4094d6d98a6a9d45fbb14cbd124c5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -327,6 +327,55 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { EXPECT_EQ(0, negate_shape.dimensions_size()); } +// pow(pow(A, X), Y) => pow(A, X*Y) +TEST_F(AlgebraicSimplifierTest, PowerOfPower) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + HloComputation::Builder builder(TestName()); + HloInstruction* base = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* inner_power = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, + inner_power, exp2)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Power(base, op::Multiply(exp1, exp2))); +} + +// Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex +// numbers. +TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { + Shape r0c64 = ShapeUtil::MakeShape(C64, {}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + HloComputation::Builder builder(TestName()); + HloInstruction* base = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0c64, "param1")); + HloInstruction* exp2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0c64, "param2")); + HloInstruction* inner_power = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, + inner_power, exp2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + // Test that A/1 is simplified to A for a scalar. TEST_F(AlgebraicSimplifierTest, DivOneScalar) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -761,8 +810,10 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Divide(op::Constant(), param0)); - EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); + EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); + EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement(), + 1); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -1622,8 +1673,11 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { ConvolutionDimensionNumbers dnums; std::vector in_dims; int in_channel_idx = -1; - dnums.add_spatial_dimensions(-1); // filled in later - dnums.add_spatial_dimensions(-1); // filled in later + // filled in later + dnums.add_input_spatial_dimensions(-1); + dnums.add_output_spatial_dimensions(-1); + dnums.add_input_spatial_dimensions(-1); + dnums.add_output_spatial_dimensions(-1); for (int i = 0; i < strlen(options.dim_order); ++i) { char ch = options.dim_order[i]; if (ch == 'N') { @@ -1631,10 +1685,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { dnums.set_output_batch_dimension(i); in_dims.push_back(options.in_batch); } else if (ch == 'H') { - dnums.set_spatial_dimensions(0, i); + dnums.set_input_spatial_dimensions(0, i); + dnums.set_output_spatial_dimensions(0, i); in_dims.push_back(options.in_height); } else if (ch == 'W') { - dnums.set_spatial_dimensions(1, i); + dnums.set_input_spatial_dimensions(1, i); + dnums.set_output_spatial_dimensions(1, i); in_dims.push_back(options.in_width); } else if (ch == 'C') { dnums.set_input_feature_dimension(i); @@ -2131,8 +2187,10 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -2229,5 +2287,210 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +class DotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + int m, k, n; + bool transpose_lhs, transpose_rhs; + std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( + 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs")); + if (transpose_lhs) { + lhs = builder.AddInstruction( + HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0})); + } + auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( + 1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs")); + if (transpose_rhs) { + rhs = builder.AddInstruction( + HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0})); + } + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = + dot_should_be_transformed || (transpose_lhs && transpose_rhs); + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_CASE_P( + DotStrengthReductionTestInstantiation, DotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Bool(), + ::testing::Bool())); + +struct DotOfConcatTestSpec { + int64 m; + int64 k; + int64 n; +}; + +class DotOfConcatSimplificationTest + : public HloTestBase, + public ::testing::WithParamInterface {}; + +// Test that we transform +// dot(const, concat(A, B, C)) +// to +// add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) +TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfConcatTestSpec spec = GetParam(); + + ASSERT_GE(spec.k, 3); + + int64 k0 = spec.k / 3; + int64 k1 = spec.k / 3; + int64 k2 = spec.k - k0 - k1; + + Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k))); + + Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n}); + Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n}); + Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n}); + + HloInstruction* rhs0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, rhs0_shape, "rhs0")); + HloInstruction* rhs1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs1_shape, "rhs1")); + HloInstruction* rhs2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, rhs2_shape, "rhs2")); + + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + ASSERT_TRUE(run_successful); + + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); + auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); + auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); +} + +// Test that we transform +// dot(concat(A, B, C), const) +// to +// add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) +TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfConcatTestSpec spec = GetParam(); + + ASSERT_GE(spec.k, 4); + + int64 k0 = spec.k / 4; + int64 k1 = spec.k / 4; + int64 k2 = spec.k / 4; + int64 k3 = spec.k - k0 - k1 - k2; + + Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0}); + Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1}); + Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2}); + Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3}); + + HloInstruction* lhs0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs0_shape, "lhs0")); + HloInstruction* lhs1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, lhs1_shape, "lhs1")); + HloInstruction* lhs2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, lhs2_shape, "lhs2")); + HloInstruction* lhs3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, lhs2_shape, "lhs3")); + + Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); + HloInstruction* lhs = + builder.AddInstruction(HloInstruction::CreateConcatenate( + lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1)); + + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.m}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.m))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); + auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); + auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); + auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3)); +} + +DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { + {/*m=*/3, /*k=*/9, /*n=*/3}, // + {/*m=*/3, /*k=*/20, /*n=*/3}, // + {/*m=*/1, /*k=*/18, /*n=*/5}, // + {/*m=*/20, /*k=*/20, /*n=*/1}, // + {/*m=*/1, /*k=*/16, /*n=*/1}, // +}; + +INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, + DotOfConcatSimplificationTest, + ::testing::ValuesIn(kDotOfConcatTestSpecs)); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index ad2fee2d39a8ca183b87212bdeea22c351aaa88a..b69a6e730fc65b2e590f22115569ce27145bf6ab 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -27,191 +27,163 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace se = ::perftools::gputools; namespace xla { -AllocationTracker::AllocationTracker() : next_handle_(1) {} - -GlobalDataHandle AllocationTracker::Register(Backend* backend, - int device_ordinal, - se::DeviceMemoryBase device_memory, - const Shape& shape, - const string& tag) { - tensorflow::mutex_lock lock(allocation_mutex_); +StatusOr AllocationTracker::Register( + std::unique_ptr shaped_buffer, const string& tag) { + tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Register"; - return RegisterInternal(backend, device_ordinal, device_memory, shape, tag, - /*initial_ref_count=*/1); + return RegisterInternal(std::move(shaped_buffer), tag); } -GlobalDataHandle AllocationTracker::RegisterInternal( - Backend* backend, int device_ordinal, se::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag, int initial_ref_count) { +StatusOr AllocationTracker::RegisterInternal( + std::unique_ptr shaped_buffer, const string& tag) { VLOG(2) << "RegisterInternal(" << "tag: \"" << tag << "\" " - << "device_ordinal: " << device_ordinal << " " - << "device_memory: " << device_memory.opaque() << " " - << "shape: " << shape.ShortDebugString() << ")"; - TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); - - int64 handle; - HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); - auto handle_it = handle_map.find(device_memory.opaque()); - if (handle_it != handle_map.end()) { - handle = handle_it->second; - auto& allocation = FindOrDie(handle_to_allocation_, handle); - int ref_count = allocation->ref_count(); - CHECK_GT(ref_count, 0); - VLOG(2) << "ref_count: " << ref_count << " -> " << - (ref_count + initial_ref_count); - allocation->increment_ref_count(initial_ref_count); - } else { - handle = next_handle_++; - VLOG(2) << "ref_count: " << initial_ref_count; - InsertOrDie(&handle_map, device_memory.opaque(), handle); - auto inserted = handle_to_allocation_.emplace( - handle, MakeUnique(backend, device_ordinal, device_memory, - shape, tag, initial_ref_count)); - CHECK(inserted.second); + << "shaped_buffer: " << *shaped_buffer; + if (shaped_buffer->platform() != backend_->platform()) { + return InvalidArgument( + "AllocationTracker for platform %s cannot register buffer from " + "platform %s", + backend_->platform()->Name().c_str(), + shaped_buffer->platform()->Name().c_str()); } + int64 handle = next_handle_++; + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal()); + } GlobalDataHandle result; result.set_handle(handle); + + handle_to_shaped_buffer_[handle] = std::move(shaped_buffer); + VLOG(2) << "handle: " << handle; return result; } tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); - TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); - std::set deallocated_buffers; - TF_RETURN_IF_ERROR( - DeallocateShape(allocation->backend(), allocation->device_ordinal(), - allocation->mutable_device_memory(), allocation->shape(), - &deallocated_buffers)); - return tensorflow::Status::OK(); -} - -tensorflow::Status AllocationTracker::DeallocateShape( - Backend* backend, int device_ordinal, se::DeviceMemoryBase* device_memory, - const Shape& shape, std::set* deallocated_buffers) { - VLOG(2) << "DeallocateShape(" - << "shape: \"" << shape.ShortDebugString() << "\" " - << "device_memory: " << device_memory->opaque() << ")"; - if (ContainsKey(*deallocated_buffers, device_memory->opaque())) { - // Buffer has already been deallocated. Nothing to do. - VLOG(2) << "already deallocated"; - return tensorflow::Status::OK(); + tensorflow::mutex_lock lock(mutex_); + VLOG(2) << "Unregister(" + << "handle: " << data.handle() << ")"; + TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal())); } - // Add buffer to deallocated set so we do not try to deallocate it again - // if it is encountered again while traversing a tuple. - deallocated_buffers->insert(device_memory->opaque()); - - HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); - auto handle_it = handle_map.find(device_memory->opaque()); - if (handle_it != handle_map.end()) { - int64 handle = handle_it->second; - auto& allocation = FindOrDie(handle_to_allocation_, handle); - int ref_count = allocation->ref_count(); - VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count - 1; - allocation->decrement_ref_count(); - if (allocation->ref_count() > 0) { - // Buffer is referred to by another allocation. Don't deallocate it. - return tensorflow::Status::OK(); - } - handle_map.erase(device_memory->opaque()); - } + // Keep a nullptr as a tombstone for unregistered handles. This enables better + // error messages. That is, "handle has been deallocated" versus "handle does + // not exist". + handle_to_shaped_buffer_.at(data.handle()).reset(); - if (ShapeUtil::IsTuple(shape)) { - // Traverse into tuple recursively deallocating buffers. - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - backend->stream_executor(device_ordinal)); - TF_ASSIGN_OR_RETURN(std::vector elements, - backend->transfer_manager()->ShallowCopyTupleFromDevice( - executor, *device_memory, shape)); - - TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size()) - << "tuple has unexpected number of elements: " << elements.size() - << " != " << ShapeUtil::TupleElementCount(shape); - for (size_t i = 0; i < elements.size(); ++i) { - VLOG(2) << "recursing onto the tuple elements"; - TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i], - shape.tuple_shapes(i), - deallocated_buffers)); - } - } - - return backend->memory_allocator()->Deallocate(device_ordinal, device_memory); + return tensorflow::Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); - TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); + tensorflow::mutex_lock lock(mutex_); - if (!ShapeUtil::IsTuple(allocation->shape())) { + TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { return InvalidArgument("global data handle %lld is not a tuple", data.handle()); } + // If the on-host representation is a tuple, then the on-device one should be + // as well. + TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); - if (ShapeUtil::IsNestedTuple(allocation->shape())) { + if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("deconstructing nested tuples not yet supported"); } - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - allocation->backend()->stream_executor(allocation->device_ordinal())); - TF_ASSIGN_OR_RETURN( - std::vector element_bases, - allocation->backend()->transfer_manager()->ShallowCopyTupleFromDevice( - executor, allocation->device_memory(), allocation->shape())); - std::vector element_handles; - element_handles.reserve(element_bases.size()); - for (int i = 0; i < element_bases.size(); ++i) { - element_handles.push_back(RegisterInternal( - allocation->backend(), allocation->device_ordinal(), element_bases[i], - ShapeUtil::GetSubshape(allocation->shape(), {i}), - tensorflow::strings::StrCat(allocation->tag(), ".element_", i), - /*initial_ref_count=*/2)); + for (int i = 0; + i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape()); + ++i) { + auto element_buffer = MakeUnique( + ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i), + ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i), + shaped_buffer->platform(), shaped_buffer->device_ordinal()); + element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}), + /*index=*/{}); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle element_handle, + RegisterInternal(std::move(element_buffer), "deconstructed tuple")); + + element_handles.push_back(element_handle); } return std::move(element_handles); } -StatusOr AllocationTracker::Resolve( +StatusOr AllocationTracker::Resolve( const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); + tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr AllocationTracker::ResolveInternal( +StatusOr AllocationTracker::ResolveInternal( const GlobalDataHandle& data) { VLOG(2) << "resolve:" << data.handle(); - auto it = handle_to_allocation_.find(data.handle()); - if (it == handle_to_allocation_.end()) { + auto it = handle_to_shaped_buffer_.find(data.handle()); + if (it == handle_to_shaped_buffer_.end()) { return NotFound("no allocation record for global data handle: %lld", data.handle()); } - Allocation* allocation = it->second.get(); + ShapedBuffer* shaped_buffer = it->second.get(); - if (allocation->is_deallocated()) { + if (shaped_buffer == nullptr) { return InvalidArgument("global data handle %lld was previously deallocated", data.handle()); } - return allocation; + return shaped_buffer; } -AllocationTracker::HandleMap& AllocationTracker::GetOrCreateOpaqueToHandleMap( - int device_ordinal) { - if (opaque_to_handle_.size() <= device_ordinal) { - opaque_to_handle_.resize(device_ordinal + 1); +void AllocationTracker::AddAllocationOrIncrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; + auto it = allocation_map.find(device_memory.opaque()); + if (it == allocation_map.end()) { + allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, + /*ref_count=*/1}; + } else { + it->second.ref_count++; } - return opaque_to_handle_[device_ordinal]; +} + +Status AllocationTracker::DecrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; + auto it = allocation_map.find(device_memory.opaque()); + TF_RET_CHECK(it != allocation_map.end()); + Allocation& allocation = it->second; + TF_RET_CHECK(allocation.ref_count >= 1); + if (allocation.ref_count == 1) { + TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( + device_ordinal, &device_memory)); + allocation_map.erase(it); + } else { + allocation.ref_count--; + } + return tensorflow::Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index ebbf35b6fe87bc7322ccb99cfe8f8eed56de06b3..8b25cbb482720f7debe95bb5ff74afe696bd8b73 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -28,147 +28,92 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" namespace xla { -// A global allocation in device space, tracked by the XLA service. -class Allocation { - public: - Allocation(Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag, int initial_ref_count) - : backend_(backend), - device_ordinal_(device_ordinal), - device_memory_(device_memory), - shape_(shape), - tag_(tag), - ref_count_(initial_ref_count) {} - - Backend* backend() const { return backend_; } - int device_ordinal() const { return device_ordinal_; } - perftools::gputools::DeviceMemoryBase device_memory() const { - return device_memory_; - } - const Shape& shape() const { return shape_; } - const string& tag() const { return tag_; } - - bool is_deallocated() const { - CHECK_GE(ref_count_, 0); - return ref_count_ == 0; - } - int ref_count() const { - CHECK_GE(ref_count_, 0); - return ref_count_; - } - void increment_ref_count(int inc) { - CHECK_GT(ref_count_, 0); - CHECK_LE(ref_count_, INT_MAX - inc); - ref_count_ += inc; - } - void decrement_ref_count() { - CHECK_GT(ref_count_, 0); - --ref_count_; - } - perftools::gputools::DeviceMemoryBase* mutable_device_memory() { - return &device_memory_; - } - - private: - // The backend that the memory is allocated on. - Backend* backend_; - - // The device that the memory is allocated on. - int device_ordinal_; - - // The pointer to this allocation. - perftools::gputools::DeviceMemoryBase device_memory_; - - // The shape of this allocation. - Shape shape_; - - // An informal description of this allocation shown in tools. - string tag_; - - // This is the number of Allocation objects which refer to this memory - // allocation. - int ref_count_; - - // Return a string representation of this allocation for debugging or logging - // purposes. - string ToString() const; -}; - // Tracks allocations for the XLA service; allocations can be registered // with shape/device/tag and resolved from a handle for later use. class AllocationTracker { public: - AllocationTracker(); + // The allocator is used for deallocating memory when allocations are + // deregistered. All registered allocations must have the same platform as the + // allocator. + AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {} - // Registers device memory with a given shape, device identifier, and tag, and - // returns a corresponding handle that can be used for talking to XLA - // clients. - GlobalDataHandle Register(Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag); + // Registers a shaped buffer of device memory, and returns a corresponding + // handle that can be used for talking to XLA clients. + StatusOr Register( + std::unique_ptr shaped_buffer, const string& tag); // Unregister the allocation for the given data handle. - tensorflow::Status Unregister(const GlobalDataHandle& data); + Status Unregister(const GlobalDataHandle& data); // Returns a vector of global data handles that point to the tuple elements. StatusOr> DeconstructTuple( const GlobalDataHandle& Data); - // Resolve a handle from an XLA client to an allocation, or provide an - // error status to say whether it was not found (or found, but found - // deallocated). - StatusOr Resolve(const GlobalDataHandle& data); + // Resolve a handle from an XLA client to a shaped buffer, or provide an error + // status to say whether it was not found (or found, but found deallocated). + StatusOr Resolve(const GlobalDataHandle& data); private: - // Internal helper which resolves the given GlobalDataHandle to an Allocation. - StatusOr ResolveInternal(const GlobalDataHandle& data) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - GlobalDataHandle RegisterInternal( - Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, const Shape& shape, - const string& tag, int initial_ref_count) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - // Helper function which deallocates the memory buffer containing the given - // shape referred to by device_memory. Tuples are traversed recursively - // deallocating all nested buffers. The parameter deallocated_buffers contains - // the set of buffers deallocated so far stored as opaque values (void *) from - // DeviceMemoryBase. Keeping track of deallocated buffers prevents - // double-freeing of buffers which may be referred to more than once in a - // nested tuple. - tensorflow::Status DeallocateShape( - Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase* device_memory, const Shape& shape, - std::set* deallocated_buffers) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - // Returns the opaque_to_handle_ map for the given device_ordinal, creating - // a new map if there is not one for the device_ordinal. - using HandleMap = std::map; - HandleMap& GetOrCreateOpaqueToHandleMap(int device_ordinal) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - tensorflow::mutex allocation_mutex_; // Guards the allocation mapping. + // Data structure encapsulating single memory allocation on the device. + struct Allocation { + // The pointer to this allocation. + perftools::gputools::DeviceMemoryBase device_memory; + + // The device that the memory is allocated on. + int device_ordinal; + + // This is the number of times this memory allocation is refered to by + // registered data handles. + int ref_count; + }; + + // Internal helper which resolves the given GlobalDataHandle to a + // ShapedBuffer. + StatusOr ResolveInternal(const GlobalDataHandle& data) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Internal helper which registers a shaped buffer. + StatusOr RegisterInternal( + std::unique_ptr shaped_buffer, const string& tag) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Adds the given device address to the allocation tracker, or if it already + // exists, then increment it's reference count. + void AddAllocationOrIncrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Decrements the reference count of the given device memory. Then, if it is + // zero, deallocate the memory. + Status DecrementRefCount(perftools::gputools::DeviceMemoryBase device_memory, + int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // A map from device memory opaque value to allocation. One such map is + // maintained per device ordinal. + using AllocationMap = tensorflow::gtl::FlatMap; + + tensorflow::mutex mutex_; + + // Backend to use with this tracker. The backend supplies the memory allocator + // to use when deallocating memory. + Backend* backend_; // The next handle to assign to an allocation, guarded by the same mutex as // the mapping as they'll be mutated at the same time. - int64 next_handle_ GUARDED_BY(allocation_mutex_); + int64 next_handle_ GUARDED_BY(mutex_); - // A map from DeviceMemoryBase to handle for each device_ordinal. - std::vector opaque_to_handle_ GUARDED_BY(allocation_mutex_); + // A map from device ordinal to AllocationMap. + tensorflow::gtl::FlatMap opaque_to_allocation_map_ + GUARDED_BY(mutex_); - // Mapping from GlobalDataHandle handle to the corresponding registered - // Allocation object. - std::map> handle_to_allocation_ - GUARDED_BY(allocation_mutex_); + // A map from data handle to ShapedBuffer. + tensorflow::gtl::FlatMap> + handle_to_shaped_buffer_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc similarity index 50% rename from tensorflow/compiler/xla/service/batchnorm_rewriter.cc rename to tensorflow/compiler/xla/service/batchnorm_expander.cc index abe881cd1a58a6173b9b93f10a7308d70106c889..b806d61663e2ca371d90dfe39d7fe66becfe4bc2 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include #include @@ -45,9 +45,9 @@ limitations under the License. namespace xla { -// BatchNormRewriterVisitor traverses the HLO computation and rewrites BatchNorm +// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. -class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { +class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { public: // Default visitor action is to do nothing and return OK. Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { @@ -68,10 +68,10 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } - ~BatchNormRewriterVisitor() override = default; + ~BatchNormExpanderVisitor() override = default; private: - explicit BatchNormRewriterVisitor(HloComputation* computation, + explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, bool rewrite_grad_op, bool use_fusion) @@ -85,16 +85,16 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { HloOpcode opcode) { HloComputation::Builder b("scalar_computation"); auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); + 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); + 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); auto scalar_op = b.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), opcode, scalar_lhs, scalar_rhs)); return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - // Current HloComputation instance the BatchNormRewriter is + // Current HloComputation instance the BatchNormExpander is // traversing. HloComputation* computation_; @@ -130,11 +130,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { } }; -bool BatchNormRewriterVisitor::Run(HloComputation* computation, +bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, bool rewrite_grad_op, bool use_fusion) { - BatchNormRewriterVisitor visitor( + BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, @@ -144,31 +144,46 @@ bool BatchNormRewriterVisitor::Run(HloComputation* computation, return visitor.changed_; } -Status BatchNormRewriterVisitor::HandleBatchNormTraining( +Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction* batch_norm) { if (!rewrite_training_op_) { return Status::OK(); } + + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + // Expand batch norm training into smaller HLO ops. HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); + PrimitiveType ptype = operand_shape.element_type(); int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(size_in_elements / feature_count))); + auto elements_per_feature_literal = + Literal::CreateR0(size_in_elements / feature_count); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); const Shape feature_shape = scale->shape(); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + auto zero_literal = Literal::CreateR0(0.0f); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -177,107 +192,114 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( } } - auto scale_broadcasted = computation_->AddInstruction( + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(F32, HloOpcode::kAdd); + GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // X^2. - auto operand_squared = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, operand, operand)); // Sum[X]. - auto sum = computation_->AddInstruction(HloInstruction::CreateReduce( - feature_shape, operand, zero, dimensions_without_feature, - add_reduce_computation)); + auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, + dimensions_without_feature, + add_reduce_computation)); // Sum[X^2]. - auto squared_sum = computation_->AddInstruction(HloInstruction::CreateReduce( + auto squared_sum = add(HloInstruction::CreateReduce( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); // Fuse two parallel reduces together to improve performance. - if (use_fusion_) { - auto tuple = computation_->AddInstruction( - HloInstruction::CreateTuple({sum, squared_sum})); + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); auto fused = computation_->CreateFusionInstruction( {tuple, sum, squared_sum, operand_squared}, HloInstruction::FusionKind::kInput); - sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - squared_sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + squared_sum = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // E[X]. - auto mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto square_mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); // E^2[X]. - auto mean_square = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean_square = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, mean, mean)); // Var[X]. - auto var = computation_->AddInstruction(HloInstruction::CreateBinary( + auto var = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = computation_->AddInstruction( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kAdd, - scaled_normalized, offset_broadcasted)); - - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({shifted_normalized, mean, var}))); + auto shifted_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + + auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); + + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } -Status BatchNormRewriterVisitor::HandleBatchNormInference( +Status BatchNormExpanderVisitor::HandleBatchNormInference( HloInstruction* batch_norm) { if (!rewrite_inference_op_) { return Status::OK(); @@ -286,6 +308,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); int64 feature_index = batch_norm->feature_index(); + PrimitiveType ptype = operand_shape.element_type(); HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -293,8 +316,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( HloInstruction* var = batch_norm->mutable_operand(4); const Shape feature_shape = scale->shape(); + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; @@ -304,56 +329,75 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( } } - auto scale_broadcasted = computation_->AddInstruction( + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted); + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + if (batch_norm->has_sharding()) { + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + shifted_normalized->set_sharding(batch_norm->sharding()); + } TF_CHECK_OK( ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); return Status::OK(); } -Status BatchNormRewriterVisitor::HandleBatchNormGrad( +Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction* batch_norm) { // Use the following formulas to calculate gradients: // scale_grad = @@ -370,9 +414,17 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( if (!rewrite_grad_op_) { return Status::OK(); } + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); const Shape activation_shape = activation->shape(); + PrimitiveType ptype = activation_shape.element_type(); HloInstruction* scale = batch_norm->mutable_operand(1); const Shape feature_shape = scale->shape(); HloInstruction* mean = batch_norm->mutable_operand(2); @@ -383,18 +435,26 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(size_in_elements / feature_count))); - - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); - - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + auto elements_per_feature_literal = + Literal::CreateR0(size_in_elements / feature_count); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + + auto zero_literal = Literal::CreateR0(0.0f); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); + + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; @@ -404,141 +464,146 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( } } - auto scale_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, scale, {feature_index})); - auto variance_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, variance, {feature_index})); + auto scale_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, scale, {feature_index})); + auto variance_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, variance, {feature_index})); // E[X]. - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kAdd, variance, epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)), + neg_half)); + + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, + epsilon)), + neg_half)); // X - E[X]. - auto activation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - activation, mean_broadcasted)); + auto activation_minus_mean = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); // Grad[Y] * (X - E[X]). - auto grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + auto grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, activation_minus_mean)); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(F32, HloOpcode::kAdd); + GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // sum(Grad[Y] * (X - E[X])). auto sum_grad_output_times_activiation_minus_mean = - computation_->AddInstruction(HloInstruction::CreateReduce( + add(HloInstruction::CreateReduce( feature_shape, grad_output_times_activiation_minus_mean, zero, dimensions_without_feature, add_reduce_computation)); // Grad[beta] = Sum(Grad[Y]). - auto grad_beta = computation_->AddInstruction(HloInstruction::CreateReduce( + auto grad_beta = add(HloInstruction::CreateReduce( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_) { - auto tuple = computation_->AddInstruction(HloInstruction::CreateTuple( + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple( {sum_grad_output_times_activiation_minus_mean, grad_beta})); auto fused = computation_->CreateFusionInstruction( {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, HloInstruction::FusionKind::kInput); - sum_grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum_grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - grad_beta = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + grad_beta = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = computation_->AddInstruction(HloInstruction::CreateBinary( + auto grad_scale = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); // I2 = Sum(Grad[Y]) - auto I2 = computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, grad_beta, {feature_index})); + auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, + {feature_index})); // I3 = Sum(Grad[Y] * (X - E[X])) - auto I3 = computation_->AddInstruction(HloInstruction::CreateBroadcast( + auto i3 = add(HloInstruction::CreateBroadcast( activation_shape, sum_grad_output_times_activiation_minus_mean, {feature_index})); // I4 = (X - E[X]) * I3 - auto I4 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, I3, activation_minus_mean)); + auto i4 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); // I5 = I4 / (Var[X] + epsilon) - auto I5 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, I4, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kAdd, variance_broadcasted, epsilon)))); + auto i5 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, i4, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)))); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted)); - scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, - scale_times_rsqrt_var_add_epsilon, elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, + elements_per_feature)); - auto I1 = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto i1 = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, elements_per_feature)); // I6 = I1 - I2 - I5 - auto I6 = computation_->AddInstruction(HloInstruction::CreateBinary( + auto i6 = add(HloInstruction::CreateBinary( activation_shape, HloOpcode::kSubtract, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, I1, I2)), - I5)); + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, + i1, i2)), + i5)); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, I6)); + auto grad_activation = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6)); + auto tuple = + HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), activation_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}))); + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } -StatusOr BatchNormRewriter::Run(HloModule* module) { - XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), before:\n" + module->ToString()); +StatusOr BatchNormExpander::Run(HloModule* module) { + XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_, + if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, rewrite_inference_op_, rewrite_grad_op_, use_fusion_)) { changed = true; } } - XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_expander.h similarity index 83% rename from tensorflow/compiler/xla/service/batchnorm_rewriter.h rename to tensorflow/compiler/xla/service/batchnorm_expander.h index f601741d964376058a2bafade311ede4c8567fd2..4ad987085da91684bb7891070afeefd19be4138f 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ #include @@ -26,18 +26,18 @@ namespace xla { // A pass which rewrites batch norm operations into more operations. Breaking a // big operation into smaller operations helps leverage our generic fusion // logic. -class BatchNormRewriter : public HloPassInterface { +class BatchNormExpander : public HloPassInterface { public: // When use_fusion is set, a multi-output fusion node is created. - BatchNormRewriter(bool rewrite_training_op = false, + BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, bool rewrite_grad_op = false, bool use_fusion = true) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} - ~BatchNormRewriter() = default; - tensorflow::StringPiece name() const override { return "batchnorm_rewriter"; } + ~BatchNormExpander() = default; + tensorflow::StringPiece name() const override { return "batchnorm_expander"; } // Run operation expander on the given computation. Returns whether the // computation was changed. @@ -52,4 +52,4 @@ class BatchNormRewriter : public HloPassInterface { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc similarity index 93% rename from tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc rename to tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 590f79aee51ccf410823b91fd8ad09fc7c429c7d..aa36e64b07099a372dab67babc7a18a2d39596bc 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include #include @@ -36,10 +36,10 @@ limitations under the License. namespace xla { namespace { -using BatchNormRewriterTest = HloTestBase; +using BatchNormExpanderTest = HloTestBase; // Test that we expand BatchNormTraining. -TEST_F(BatchNormRewriterTest, BatchNormTraining) { +TEST_F(BatchNormExpanderTest, BatchNormTraining) { Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); Shape scale_shape = ShapeUtil::MakeShape(F32, {2}); Shape offset_shape = ShapeUtil::MakeShape(F32, {2}); @@ -63,7 +63,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); - BatchNormRewriter rewriter(/*rewrite_training_op=*/true, + BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); @@ -73,7 +73,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) { } // Test that we expand BatchNormGrad. -TEST_F(BatchNormRewriterTest, BatchNormGrad) { +TEST_F(BatchNormExpanderTest, BatchNormGrad) { Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); Shape scale_shape = ShapeUtil::MakeShape(F32, {2}); Shape mean_shape = ShapeUtil::MakeShape(F32, {2}); @@ -105,7 +105,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad); - BatchNormRewriter rewriter(/*rewrite_training_op=*/true, + BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 033034b4210fa1bd3ae78f0ef869ec2be879f229..7ece79d781acfaffc21d6a29e8a12e68622a1617 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -581,6 +581,7 @@ Status GatherComputationsByAllocationType( instruction->called_computations()) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: // Call and while must be called from a computation with global // allocations as they may return references to buffers inside the @@ -976,8 +977,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, - // since buffers for kCall and kWhile sub-computations are only live for the - // duration of their calling instructions. + // since buffers for kCall, kWhile, and kConditional sub-computations are + // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; SequentialHloOrdering::HloModuleSequence module_sequence; FlatSet all_buffers_to_assign; @@ -1265,7 +1266,6 @@ const LogicalBuffer* AddBufferToColocatedSet( // CopyInsertion ensures root points-to set is unambiguous and distinct. const auto& points_to = points_to_analysis.GetPointsToSet(instruction); DCHECK(!points_to.IsAmbiguous()); - DCHECK(points_to.IsDistinct()); colocated_set->push_back(points_to.element(index)[0]); return colocated_set->back(); } @@ -1273,7 +1273,8 @@ const LogicalBuffer* AddBufferToColocatedSet( } // namespace // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated -// in the same allocation (currently just supports kWhile and kCall). +// in the same allocation (currently just supports kWhile, kCall, and +// kConditional). void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, @@ -1337,6 +1338,26 @@ void BufferAssigner::BuildColocatedBufferSets( &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); + } else if (opcode == HloOpcode::kConditional) { + const HloInstruction* conditional_hlo = instruction; + ShapeUtil::ForEachSubshape( + conditional_hlo->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector colocated_set; + // Add conditional.result. + AddBufferToColocatedSet(conditional_hlo, index, + points_to_analysis, &colocated_set); + // Add conditional.true_computation.root. + AddBufferToColocatedSet( + conditional_hlo->true_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + // Add conditional.false_computation.root. + AddBufferToColocatedSet( + conditional_hlo->false_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); } } } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 89410f42bd7b5fa8f9b380c868fcd4fedb54576c..6fc9d783f1b34de8c0f93c6aa342591891d08eaf 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -85,7 +85,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, MakeUnique(module), + module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); @@ -94,7 +94,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, MakeUnique(module), + module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, false, std::move(colorer)) @@ -166,6 +166,15 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr BuildR0F32UnaryOpComputation( + HloOpcode opcode, const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param)); + return builder.Build(); + } + // Verifies that the given instruction hlo has a valid input buffer assigned, // i.e., the parameter number matches the op's. const BufferAllocation& GetAssignedInputAllocation( @@ -740,6 +749,56 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { << " instructions; total buffer size " << size0 + sizec + sizeb; } +TEST_F(BufferAssignmentTest, ExampleConditional) { + auto module = CreateNewModule(); + auto true_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); + auto false_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor")); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + auto const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.4f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + r0f32_, pred, const1, true_computation, const2, false_computation)); + module->AddEntryComputation(builder.Build()); + + const std::vector conditional_instrs = + GetInstructions(conditional); + const std::vector true_instrs = + GetInstructions(true_computation->root_instruction()); + const std::vector false_instrs = + GetInstructions(false_computation->root_instruction()); + EXPECT_EQ(4, conditional_instrs.size()); + EXPECT_EQ(2, true_instrs.size()); + EXPECT_EQ(2, false_instrs.size()); + + auto buffers = RunBufferAssignment(module.get()); + ValidateBuffers(conditional_instrs, *buffers); + ValidateBuffers(true_instrs, *buffers); + ValidateBuffers(false_instrs, *buffers); + + EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers)) + << "Should be reuse between conditional and true computation."; + EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers)) + << "Should be reuse between conditional and false computation."; + EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers)) + << "Should be reuse between true and false computations."; + + const BufferAllocation& conditional_buffer = + GetTopLevelAllocation(*buffers, conditional); + const BufferAllocation& true_buffer = + GetTopLevelAllocation(*buffers, true_computation->root_instruction()); + const BufferAllocation& false_buffer = + GetTopLevelAllocation(*buffers, false_computation->root_instruction()); + EXPECT_EQ(conditional_buffer.size(), true_buffer.size()); + EXPECT_EQ(conditional_buffer.size(), false_buffer.size()); +} + TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); @@ -1360,10 +1419,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateParameter(1, shape_3x4, "param_b")); auto param_c = builder.AddInstruction( HloInstruction::CreateParameter(2, shape_4x4, "param_c")); - auto dot_ab = builder.AddInstruction(HloInstruction::CreateBinary( - shape_2x4, HloOpcode::kDot, param_a, param_b)); - auto dot_bc = builder.AddInstruction(HloInstruction::CreateBinary( - shape_3x4, HloOpcode::kDot, param_b, param_c)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_ab = builder.AddInstruction( + HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); + auto dot_bc = builder.AddInstruction( + HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); @@ -1448,7 +1510,7 @@ class WhileBufferAssignmentTest : public HloTestBase { auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, MakeUnique(module, sequence), + module, xla::MakeUnique(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); @@ -1469,7 +1531,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1526,7 +1588,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1538,8 +1600,6 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); - auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -1556,10 +1616,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto body1 = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); - auto tuple1 = builder.AddInstruction( - HloInstruction::CreateTuple({input0, weights0, output1})); auto while1 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); RunCopyInsertion(module.get()); @@ -1575,7 +1633,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -1640,7 +1698,7 @@ static bool IsPostOrderTraversal( } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -1676,11 +1734,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( - while0->shape(), HloOpcode::kAdd, while0, while1)); - module->AddEntryComputation(builder.Build()); + while0->shape(), HloOpcode::kAdd, gte0, gte1)); - RunCopyInsertion(module.get()); + module->AddEntryComputation(builder.Build()); { FlattenCallGraph flatten; @@ -1688,84 +1749,35 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { EXPECT_TRUE(result); } + RunCopyInsertion(module.get()); + auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually // crafted sequence. - std::vector sequence_for_buffer_assigment = { - input1, weights1, one, output1, tuple1, while1, input0, - weights0, zero, output0, tuple0, while0, root_add}; + sequence[module->entry_computation()] = { + input1, weights1, one, output1, while1->operand(0), while1, + input0, weights0, zero, output0, while0->operand(0), while0, + gte0, gte1, root_add}; // If this ASSERT_TRUE fails, we constructed a bogus sequence above // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment)); - - sequence[module->entry_computation()] = - std::move(sequence_for_buffer_assigment); + ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); auto assignment = BufferAssigner::Run( module.get(), - MakeUnique(module.get(), sequence), ByteSizeOf, - [](LogicalBuffer::Color) { return 1; }) + xla::MakeUnique(module.get(), sequence), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } -// Test buffer assignment for while nodes with multiple uses. -// TODO(b/37245345): Fix buffer assignment for this case. -TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { - auto module = MakeUnique(TestName()); - auto builder = HloComputation::Builder(TestName()); - - auto input0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape_, "input0")); - auto weights0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, data_shape_, "weights0")); - - auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); - auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); - - auto cond0 = - module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); - auto body0 = - module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); - - auto tuple0 = builder.AddInstruction( - HloInstruction::CreateTuple({input0, weights0, output0})); - auto while0 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); - auto while1 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0)); - - auto get0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); - auto get1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1)); - module->AddEntryComputation(builder.Build()); - - RunCopyInsertion(module.get()); - - { - FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); - EXPECT_TRUE(result); - } - - auto assignment = RunBufferAssignment(module.get()); - - EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); -} - TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 56600b583803e23324db778959de620440fce5cf..13825fe05bb1b98045f1a3dac3d7272a2d1151fb 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -120,7 +120,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -167,10 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run( - module.get(), - MakeUnique(module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = + BufferLiveness::Run(module.get(), xla::MakeUnique( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -216,7 +216,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -250,7 +250,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -294,7 +294,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { std::vector order = {param, negate, exp, add}; module_sequence.emplace(computation, order); auto liveness = - BufferLiveness::Run(module.get(), MakeUnique( + BufferLiveness::Run(module.get(), xla::MakeUnique( module.get(), module_sequence)) .ConsumeValueOrDie(); @@ -334,7 +334,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -370,7 +370,7 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. @@ -409,7 +409,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -474,7 +474,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -536,7 +536,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -624,8 +624,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. @@ -736,8 +736,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 1adecdb939cb2c1259003d3be2c90b5a299b0f30..13eb02ca012f44b2b5ed7c6f5becb7d54b07c33c 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -54,6 +54,7 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { CallContext GetInstructionCallContext(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 0395ea8c8b52315f7ca2221f412750ebadda2dd8..1ea7d538cd515c3098b6a1f03c6146d288330406 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -34,12 +34,13 @@ using ::testing::UnorderedElementsAre; class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. - std::unique_ptr MakeScalarComputation() { + std::unique_ptr MakeScalarComputation( + HloOpcode opcode = HloOpcode::kNegate) { HloComputation::Builder builder(TestName() + ".ScalarComputation"); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); builder.AddInstruction( - HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0)); + HloInstruction::CreateUnary(kScalarShape, opcode, param0)); return builder.Build(); } @@ -236,6 +237,54 @@ TEST_F(CallGraphTest, ContextBothComputations) { EXPECT_EQ(CallContext::kBoth, sub_node.context()); } +TEST_F(CallGraphTest, ComputationWithConditional) { + // Test a call graph of a module with a conditional. + auto module = CreateNewModule(); + HloComputation* true_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); + HloComputation* false_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor)); + + HloComputation::Builder builder(TestName()); + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction* const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction* const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.6f))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, const1, true_computation, const2, + false_computation)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + std::unique_ptr call_graph = CallGraph::Build(module.get()); + + EXPECT_EQ(3, call_graph->nodes().size()); + + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(1, entry_node.callsites().size()); + + const CallSite& conditional_callsite = entry_node.callsites()[0]; + EXPECT_EQ(conditional, conditional_callsite.instruction()); + EXPECT_THAT(conditional_callsite.called_computations(), + UnorderedElementsAre(true_computation, false_computation)); + EXPECT_EQ(CallContext::kSequential, conditional_callsite.context()); + EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); + + const CallGraphNode& true_node = call_graph->GetNode(true_computation); + EXPECT_TRUE(true_node.callees().empty()); + EXPECT_EQ(1, true_node.callers().size()); + EXPECT_EQ(entry_computation, true_node.callers()[0]); + + const CallGraphNode& false_node = call_graph->GetNode(false_computation); + EXPECT_TRUE(false_node.callees().empty()); + EXPECT_EQ(1, false_node.callers().size()); + EXPECT_EQ(entry_computation, false_node.callers()[0]); +} + TEST_F(CallGraphTest, ComplexGraph) { // Test a call graph of a module with several computation called in various // contexts. The call graph looks like: diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 5f021900c8b647077661da1cdec9d462bbb0146e..fc67330f5cbdbcb0d1a259d284599916a908d1fe 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -97,21 +97,32 @@ class Compiler { // Returns the ID of the platform that this compiler targets. virtual perftools::gputools::Platform::Id PlatformId() const = 0; + // Runs Hlo passes to optimize the given Hlo module, returns the optimized + // module. + virtual StatusOr> RunHloPasses( + std::unique_ptr module, + perftools::gputools::StreamExecutor* executor) = 0; + // Compiles the HLO module for execution on a device given by the executor, - // and returns an executable object or an error status. Takes ownership of the - // HLO module and is free to transform it. + // and returns an executable object or an error status. No HLO passes are + // applied to module. Generally a module should be passed through RunHloPasses + // prior to calling this method because the some HLO passes are required for + // correctness. Takes ownership of the HLO module and is free to transform it. // // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. // // Use the overload below to compile computations that run in parallel. - virtual StatusOr> Compile( + virtual StatusOr> RunBackend( std::unique_ptr module, perftools::gputools::StreamExecutor* executor) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. + // + // TODO(b/68666782): Remove this method after adding support for multiple + // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( std::vector> modules, std::vector> diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 6b7b0d25e87edf39d9f3c0c19305ebe8f173bafe..657fba6b6231104bf47f9dec80f7cd36a0ba3efd 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -52,6 +52,12 @@ Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { /* static */ StatusOr> DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); + if (proto.replica_count() <= 0 || proto.computation_count() <= 0) { + return InvalidArgument( + "Invalid device assignment topology: replica_count=%d, " + "computation_count=%d", + proto.replica_count(), proto.computation_count()); + } auto assignment = MakeUnique(proto.replica_count(), proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 0453a698a09b740d68b35258ede7c537fcf290d4..cd983bc03e993caed883916de01d75dffdbc4bab 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,15 +15,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" -#include - +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -31,597 +33,1174 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + namespace { -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; +bool IsEntryParameterValue(const HloValue& value) { + const HloComputation* computation = value.defining_instruction()->parent(); + return value.defining_instruction()->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation(); +} + +bool IsConstantValue(const HloValue& value) { + return value.defining_instruction()->opcode() == HloOpcode::kConstant; +} + +bool ValueIsReadOnly(const HloValue& value) { + return IsConstantValue(value) || IsEntryParameterValue(value); +} -// InstructionCopier encapsulates indices at which to copy 'instruction'. -// All 'instruction' users in 'copy_users' are updated to use the copy. +// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in +// 'indices_to_copy'. Add control edges from the respective kCopy instructions +// in deep copy of 'from' to the respective kCopy instruction in the deep copy +// of 'to'. // -// Instruction copies are generated in two phases: -// 1) Recording buffer indices at which 'instruction' requires copies (i.e. -// setting 'indices_to_copy_[index]'=true). -// 2) Inserting kCopy instructions based on indices recorded in phase 1). -// *) Array instructions are copied by inserting a single kCopy instruction. -// *) Tuple-shaped instructions are copied by recursively expanding tuples -// (and tuple-shaped elements), and inserting kCopy instructions for any -// tuple elements which require a copy. As the recursion unwinds, new tuple -// instructions are added to gather the copied (and uncopied) references -// into the output tuple (i.e. the copy of the tuple-shaped instruction). +// Requirements: 'from' and 'to' must have compatible shapes. // -// Example two-element tuple with one element that needs a copy: +// For example, suppose 'from' and 'to' are two-element tuples where index 0 is +// the only index to copy. Prior to deep-copying we have: // -// original-instruction -// / \ -// GTE(0) GTE(1) -// | | -// Copy | -// \ / -// Tuple // copied-instruction // -// As an optimization, if the original instruction is itself a Tuple -// instruction, we elide the unnecessary extra GTE and Tuple instructions, -// and just insert the copy into a new Tuple instruction, with control -// dependencies to ensure the copy occurs after any possible interference. -class InstructionCopier { - public: - InstructionCopier(HloInstruction* instruction, - const std::vector& copy_users) - : instruction_(instruction), - copy_users_(copy_users), - indices_to_copy_(instruction->shape()), - control_predecessors_(instruction->shape()) {} - - // Sets indices that are read-only, and thus do not need to be copied. - void SetReadOnlyIndices(const ShapeTree& read_only_indices) { - read_only_indices_ = read_only_indices; - } +// 'from' +// | +// ... +// | +// 'to' +// +// DeepCopyAndAddControlEdges produces: +// +// 'from' +// / \ +// GTE GTE +// | | +// Copy | +// / \ / +// | Tuple +// | | +// ctrl ... +// edge | +// | | +// | 'to' +// | / \ +// | GTE GTE +// \ | | +// Copy | +// \ / +// Tuple +// +StatusOr> +DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, + const ShapeTree& indices_to_copy) { + DCHECK(ShapeUtil::Compatible(from->shape(), to->shape())); + // to/from_copy_tree hold the kCopy instruction produces by the deep + // copies. Elements which are not copied (indices_to_copy.element(index) == + // false) have nullptr at that index. + ShapeTree from_copy_tree(from->shape(), + /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy, + from->parent()->DeepCopyInstruction( + from, &indices_to_copy, &from_copy_tree)); - // Sets copy overrides, which are copy instructions to use at each index. This - // is used to share a single copy of read-only entry parameters and constants - // between multiple While loops. - void SetCopyOverrides(const ShapeTree& copy_overrides) { - copy_overrides_ = copy_overrides; + ShapeTree to_copy_tree(to->shape(), /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN( + HloInstruction * to_deep_copy, + to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree)); + + // Add control edges between the respective kCopy instructions. + for (const auto& pair : from_copy_tree) { + const ShapeIndex& index = pair.first; + HloInstruction* from_copy = pair.second; + HloInstruction* to_copy = to_copy_tree.element(index); + if (from_copy == nullptr) { + TF_RET_CHECK(to_copy == nullptr); + continue; + } + TF_RET_CHECK(to_copy != nullptr); + TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy)); } - // Returns true if all recorded indices are false (returns true otherwise). - bool HasAllIndicesFalse() const; + return std::make_pair(from_deep_copy, to_deep_copy); +} - // Records instruction buffer indices which point-to a Parameter or Constant. - Status RecordIndicesWhichPointToParamOrConstant( - const TuplePointsToAnalysis& points_to_analysis); +// Compute the indices of the loop state which need copies in order to avoid +// live range interference. Generally, an element in the loop state does not +// need to be copied if the element is passed through transparently through the +// body. +// +// Returns whether any indices need to be copied. +bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, + const HloInstruction* xla_while, + ShapeTree* indices_to_copy) { + DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape())); - // Records instruction buffer indices to copy which are necessary to ensure: - // *) PointsToSet of 'instruction_' is unambiguous and distinct. - // *) No liveness interference between 'instruction_' and 'other_instruction'. - // - // If 'read_only_indices_out' is non-null, read-only indices are set to true. - Status RecordIndicesToCopyForColocatingBuffers( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out); + bool any_copies = false; + const HloInstruction* init = xla_while->operand(0); + for (auto& pair : *indices_to_copy) { + const ShapeIndex& index = pair.first; + bool& should_copy = pair.second; + // If there is any ambiguity, then loop state must be copied. + if (dataflow.GetValueSet(init, index).values().size() > 1 || + dataflow.GetValueSet(xla_while, index).values().size() > 1) { + should_copy = true; + } else { + // If the output of the while instruction is not the same as the init + // value of the while, then this element is not passed through the body + // transparently and must be copied. + should_copy = dataflow.GetUniqueValueAt(xla_while, index) != + dataflow.GetUniqueValueAt(init, index); + } + any_copies |= should_copy; + } + return any_copies; +} - // Records control predecessors to add for inserted copy instructions. - // 'parameter' must have the same shape as the instruction that will be - // copied, and must define all buffers in the shape. Control predecessors are - // only recorded for indices that have already been marked for copying. - Status RecordControlPredecessors( - const TuplePointsToAnalysis& points_to_analysis, - HloInstruction* parameter); +// Add kCopy instructions around the given kWhile instruction to eliminate any +// possible live range interference of HLO values assuming a dependency-based +// ordering (HloDependencyOrdering). Copies are added conservatively. There +// likely are copies which are not strictly necessary, but there are removed +// later in the pass via CopyRemover. +// +// +// Elements (each ShapeIndex) in the loop state are considered independently. A +// copy is added to each element of the loop state which is modified in the +// while body. For each such element, a total of three kCopy instructions are +// added at following locations: +// +// (1) The init value is copied before the kWhile instruction. Before: +// +// (Init) +// | +// kWhile +// | +// ... +// +// After: +// +// (Init) +// | +// kCopy +// | +// kWhile +// | +// ... +// +// This copy is necessary in case the init value is simultaneously live +// with the kWhile. +// +// (2) Copies are added to the parameter and root of the while body +// computation. Before: +// +// kParameter +// | +// ... +// | +// (body root) +// +// After: +// +// kParameter +// | +// kCopy ----------+ +// | | +// ... ctrl +// | edge +// (body root) | +// | | +// kCopy <---------+ +// +// The root kCopy becomes the new root of the computation. Both copies are +// necessary to any potential interference between the parameter value and +// the root value. The control edge prevents potential interference +// between the copies themselves. +// +// If the loop state is a tuple then the above kCopy instructions are a deep +// copy constructed of kCopy, KGetTupleElement, and kTuple instruction as +// constructed by HloInstruction::DeepCopyInstruction. +Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, + HloInstruction* xla_while) { + VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name(); + TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile); - // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy', - // and replaces all uses for instructions in 'copy_users_' with copy. - // Returns the instruction which is a copy 'instruction'. - HloInstruction* Copy(); + ShapeTree indices_to_copy(xla_while->shape()); + if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while, + &indices_to_copy)) { + VLOG(2) << "No copies necessary for kWhile instruction " + << xla_while->name(); + return Status::OK(); + } - HloInstruction* instruction() { return instruction_; } + VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:"; + for (auto& pair : indices_to_copy) { + if (pair.second) { + VLOG(2) << " " << pair.first; + } + } - const std::vector& copy_users() const { return copy_users_; } + // Deep copy init. + HloInstruction* while_init = xla_while->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * while_init_copy, + xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy)); + TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy)); - private: - // Does the given index represent a read-only buffer? - bool IsReadOnlyIndex(const ShapeIndex& index) const { - return !ShapeUtil::IsNil(read_only_indices_.shape()) && - read_only_indices_.element(index); - } + // Deep copy the parameter and the root. Extend a control edge from the copy + // of the parameter value to the corresponding copy value of the root. + HloComputation* body = xla_while->while_body(); + HloInstruction* param = body->parameter_instruction(0); + HloInstruction* root = body->root_instruction(); - // Returns the copy override at the given index, or nullptr. - HloInstruction* GetCopyOverride(const ShapeIndex& index) const { - return ShapeUtil::IsNil(copy_overrides_.shape()) - ? nullptr - : copy_overrides_.element(index); - } + // If param is the root then all indices should have been passed through the + // while body and we should have returned early above. + TF_RET_CHECK(param != root); - // Records instruction buffer indices which have ambiguous or non-distinct - // points-to sets. - Status RecordAmbiguousOrNonDistinctIndices( - const TuplePointsToAnalysis& points_to_analysis); + // Copy users before making a deep copy of the parameter as the deep copy + // will create new users of the parameter (eg, the GTE instructions of the + // deep copy). + std::vector param_users = param->users(); - // Records instruction buffer indices which have interfering live ranges - // with 'other_instruction' buffers at same index. - Status RecordIndicesWhichInterfereWithOtherInstruction( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out); + ShapeIndex current_index; + TF_ASSIGN_OR_RETURN(auto pair, + DeepCopyAndAddControlEdges(param, root, indices_to_copy)); - // Recursively inserts copies of 'instruction' tuple elements at indices - // specified in 'indices_to_copy', and returns the copy of 'instruction'. - HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index); + HloInstruction* param_copy = pair.first; + HloInstruction* root_copy = pair.second; - void RecordIndex(const ShapeIndex& index) { - *indices_to_copy_.mutable_element(index) = true; + for (HloInstruction* user : param_users) { + TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy)); } - HloInstruction* instruction_; - const std::vector copy_users_; - ShapeTree indices_to_copy_; - ShapeTree> control_predecessors_; - ShapeTree read_only_indices_; - ShapeTree copy_overrides_; -}; + body->set_root_instruction(root_copy); -bool InstructionCopier::HasAllIndicesFalse() const { - bool all_indices_false = true; - indices_to_copy_.ForEachElement( - [&all_indices_false](const ShapeIndex& /*index*/, bool data) { - if (data) { - all_indices_false = false; - } - }); - return all_indices_false; + return Status::OK(); } -Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( - const TuplePointsToAnalysis& points_to_analysis) { - const PointsToSet& points_to = - points_to_analysis.GetPointsToSet(instruction_); - // Shallow copy the instruction if the points-to set of the top-level - // buffer is ambiguous. This is necessary because the backends must know - // statically what the top-level buffer of the result is. - if (points_to.element(/*index=*/{}).size() > 1) { - RecordIndex({}); +// Removes any control dependencies to or from the given instruction. +Status StripControlDependenciesFrom(HloInstruction* instruction) { + while (!instruction->control_successors().empty()) { + TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo( + instruction->control_successors().front())); + } + + while (!instruction->control_predecessors().empty()) { + TF_RETURN_IF_ERROR( + instruction->control_predecessors().front()->RemoveControlDependencyTo( + instruction)); } - // Multiple buffers within a parameter/constant may be live out, so collect - // a set of indices at which to copy first. - points_to.ForEachElement([this](const ShapeIndex& index, - const PointsToSet::BufferList& buffers) { - if (IsReadOnlyIndex(index)) { - return; - } - for (const LogicalBuffer* buffer : buffers) { - // pointee is the HloInstruction producing the buffer which may be - // liveout. - HloInstruction* pointee = buffer->instruction(); - if (pointee->opcode() == HloOpcode::kParameter || - pointee->opcode() == HloOpcode::kConstant) { - VLOG(2) << "Parameter or constant buffer " << buffer->ToString() - << " index: " << tensorflow::str_util::Join(index, ",") - << " may be live out of computation: " << pointee->ToString(); - RecordIndex(index); - break; - } - } - }); return Status::OK(); } -Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out) { - TF_RETURN_IF_ERROR( - RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis())); - TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction( - liveness, other_instruction, read_only_indices_out)); +// Add kCopy instructions to the given module to guarantee there is no +// live-range interference. Generally interference can only occur around kWhile +// instructions which have update-in-place semantics. +Status AddCopiesToResolveInterference(HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile) { + TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } + } + } return Status::OK(); } -Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( - const TuplePointsToAnalysis& points_to_analysis) { - const PointsToSet& points_to = - points_to_analysis.GetPointsToSet(instruction_); - // Mapping from LogicalBuffer to index (used to detect non-distinct indices). - FlatMap> - buffer_to_source_indices; - points_to.ForEachElement( - [this, &buffer_to_source_indices]( - const ShapeIndex& index, const PointsToSet::BufferList& buffers) { - if (buffers.size() > 1) { - // Record ambiguous points-to set at 'index'. - if (!indices_to_copy_.element(index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " with ambiguous points-to set."; - RecordIndex(index); +// Class for removing unnecessary copies from the module. +// +// kCopy instructions are added conservatively to guarantee no live range +// interference between HLO values. This class uses a more fine-grained analysis +// to remove some of these added copies which are not strictly necessary. +class CopyRemover { + public: + CopyRemover(const HloAliasAnalysis& alias_analysis, + const HloOrdering& ordering, HloModule* module) + : module_(module), + alias_analysis_(alias_analysis), + ordering_(ordering), + buffer_value_tracker_(*module, alias_analysis, ordering) {} + + // Try to elide the given copy. The copy is elided if the instruction is not + // necessary to prevent live-range interference of HLO values. Returns true if + // copy was elided. + // + // The copy instruction is not actually removed here. Instead it is left for + // dead in the graph. Later calls to DCE will remove the instruction. + StatusOr TryElideCopy(HloInstruction* copy) { + if (buffer_value_tracker_.TryElideCopy(copy)) { + TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy)); + TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0))); + return true; + } + return false; + } + + string ToString() const { + string out = StrCat("CopyRemover, module ", module_->name(), "\n"); + StrAppend(&out, " Buffer values, in dependency order:\n"); + for (const HloBuffer& buffer : alias_analysis_.buffers()) { + StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); + } + return out; + } + + private: + // Class which tracks the HLO values within each HLO buffer in the module + // during copy removal. + // + // The values are held in a linked list where there is one list for each + // buffer. Removing a copy instruction merges together the values in the + // source buffer of the copy to the destination buffer of the copy. This class + // tracks these value lists as copies are removed from the graph (and value + // lists are merged). + // + // The BufferValueTracker object is initialized to match the state of + // HloAliasAnalysis. However, as copies are removed this state diverges. The + // values-to-buffer mapping is maintained outside of HloAliasAnalysis because + // a fully updatable alias analysis is very slow. + class BufferValueTracker { + public: + // The values held in a single HLO buffer are represented using a linked + // list. An element type in this list is ValueNode. + // + // This linked list is hand-rolled to enable efficient splicing of lists + // using only references to list elements without knowing which lists are + // being spliced. std::list requires a reference to the list object to + // splice. + struct ValueNode { + explicit ValueNode(const HloValue* v) : value(v) {} + + const HloValue* value; + + // The uses are maintained outside of HloValue::uses() because + // HloValue::uses() is not updatable (a fully updatable dataflow analysis + // is slow). + std::vector uses; + + // next/prev elements in the linked list. The list is circularly linked so + // these values are never null for elements in the list. + ValueNode* prev = nullptr; + ValueNode* next = nullptr; + }; + + BufferValueTracker(const HloModule& module, + const HloAliasAnalysis& alias_analysis, + const HloOrdering& ordering) + : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { + // Construct a list for each HLO buffer in the alias analysis. Maintain a + // map from HloValue to the respective list element representing that + // value. The map is used to construct the copy info map below. + tensorflow::gtl::FlatMap value_to_node; + for (const HloBuffer& buffer : alias_analysis.buffers()) { + // Verify values contained in the buffer are strictly ordered. This + // should always be the case after adding copies to eliminate + // interference. Specifically, the addition of the control flow edges + // between copies added around aliased operations (kWhile) guarantees + // this strict order. + for (const HloValue* value_a : buffer.values()) { + for (const HloValue* value_b : buffer.values()) { + if (value_a != value_b) { + DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, + dataflow_) || + ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, + dataflow_)) + << value_a->ToShortString() << " and " + << value_b->ToShortString() << " are not ordered"; + } } } - // For each 'buffer': record a mapping from 'buffer' to 'index'. - for (const LogicalBuffer* buffer : buffers) { - buffer_to_source_indices[buffer].push_back(index); - } - }); - // Record all non-distinct indices detected in 'buffer_to_source_indices'. - for (const auto& buff_to_src : buffer_to_source_indices) { - if (buff_to_src.second.size() == 1) { - continue; + std::vector values = buffer.values(); + std::sort(values.begin(), values.end(), + [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); + + // Create a list containing all of the values in the buffer. + AddValueList(values, &value_to_node); + } + + // Create copy_map_ which contains the source and destination values + // of all copies. + CreateCopyMap(module, value_to_node); + + XLA_VLOG_LINES(3, ToString()); + TF_DCHECK_OK(Verify()); } - for (const ShapeIndex& src_index : buff_to_src.second) { - // Record non-distinct points-to set at 'src_index'. - if (!indices_to_copy_.element(src_index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(src_index, ",") - << " because of non-distinct points-to set."; - RecordIndex(src_index); + + // Add a list containing the given values to BufferValueTracker. This + // represents the values contained in a single buffer. For each value in + // 'values' an entry is created in value_to_node which indicates the + // respective ValueNode representing that value. + void AddValueList( + tensorflow::gtl::ArraySlice values, + tensorflow::gtl::FlatMap* value_to_node) { + ValueNode* tail = nullptr; + ValueNode* head = nullptr; + for (const HloValue* value : values) { + auto new_node = new ValueNode(value); + (*value_to_node)[value] = new_node; + + // Copy the HLO values's uses into the ValueNode for the value. These + // uses in ValueNode are updated as copies are removed. + new_node->uses.reserve(value->uses().size()); + for (const HloUse& use : value->uses()) { + new_node->uses.push_back(&use); + } + + // Connect the new node into the linked list. + if (tail == nullptr) { + head = new_node; + } else { + tail->next = new_node; + new_node->prev = tail; + } + tail = new_node; } + + // The linked list is circular so connect the head and tail. + tail->next = head; + head->prev = tail; + value_lists_.insert(head); } - } - return Status::OK(); -} -Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out) { - // Record all buffer indices for 'instruction_', which interfere with - // 'other_instruction' at the same index. - ShapeUtil::ForEachSubshape( - instruction_->shape(), - [this, &liveness, other_instruction, read_only_indices_out]( - const Shape& /*subshape*/, const ShapeIndex& index) { - if (IsReadOnlyIndex(index)) { - return; + // This method also fills in copy_map_ which indicates which nodes + // in the value lists corresponding to the source and destination values of + // kCopy instructions. value_to_node should map each HloValue to its + // respective ValueNode. + void CreateCopyMap( + const HloModule& module, + const tensorflow::gtl::FlatMap& + value_to_node) { + for (HloComputation* computation : module.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + // Add copies with unambiguous source values to the map. Copies with + // ambiguous sources are not removable. + if (instruction->opcode() == HloOpcode::kCopy) { + const HloValueSet& src_value_set = + dataflow_.GetValueSet(instruction->operand(0)); + if (src_value_set.values().size() == 1) { + CopyNodes& copy_node = copy_map_[instruction]; + copy_node.dest = + value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); + copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); + } + } } - if (indices_to_copy_.element(index)) { - // Return if previous pass already set index. - return; + } + } + + ~BufferValueTracker() { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + const ValueNode* tmp = p->next; + delete p; + p = tmp; + } while (p != head); + } + } + + // Verify invariants within the linked lists. + Status Verify() const { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + // Verify links between elements are consistent. + TF_RET_CHECK(p->prev->next == p); + TF_RET_CHECK(p->next->prev == p); + + const HloInstruction* def = p->value->defining_instruction(); + if (def->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, def)) { + TF_RET_CHECK(copy_map_.at(def).dest == p); + } + for (const HloUse* use : p->uses) { + if (use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, use->instruction)) { + TF_RET_CHECK(copy_map_.at(use->instruction).src == p); + } + } + + p = p->next; + } while (p != head); + } + return Status::OK(); + } + + // Try to elide the given copy. Elision of a copy is possible only if no + // live range interference is introduced by the copy's elimination. If + // elision is possible, then the internal state (value lists) are updated, + // and true is returned. Returns false otherwise. + bool TryElideCopy(const HloInstruction* copy) { + VLOG(2) << "Trying to remove " << copy->name(); + + if (!ContainsKey(copy_map_, copy)) { + VLOG(2) << copy->name() << " is not removable"; + return false; + } + + const CopyNodes& copy_node = copy_map_.at(copy); + ValueNode* src = copy_node.src; + ValueNode* dest = copy_node.dest; + DCHECK(src != nullptr); + DCHECK(dest != nullptr); + + auto is_live_range_before = [this](const ValueNode& a, + const ValueNode& b) { + if (LiveRangeBefore(a, b)) { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is before " << b.value->ToShortString(); + return true; + } else { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is not before " << b.value->ToShortString(); + return false; } - const auto& points_to_analysis = liveness.points_to_analysis(); - // Lookup buffers for 'instruction_' and 'other_instruction'. - const auto instruction_buffers = - points_to_analysis.GetPointsToSet(instruction_).element(index); - // If 'instruction_' has ambiguous points-to-set at 'index', it would - // have been recorded in a previous pass (and we would have returned - // early at the entry to this function). As a result, here we know that - // 'instruction_' has just one buffer in its points-to-set. - CHECK_EQ(1, instruction_buffers.size()); - const LogicalBuffer* instruction_buffer = instruction_buffers[0]; - - const auto other_instruction_buffers = - points_to_analysis.GetPointsToSet(other_instruction).element(index); - // Do not insert a copy if both instructions point at the same buffer. - // This eliminates unnecessary copies of read-only tuple elements. - // If 'instruction_' and 'other_instruction' point to the same buffer, - // then that buffer is not updated on the path between the two - // instructions. Therefore, any other (possibly interference-causing) - // users of that buffer from 'other_instruction' will see the same data, - // irrespective of whether we insert a copy of this buffer at - // 'instruction_' or not. - if (other_instruction_buffers.size() == 1 && - other_instruction_buffers[0]->id() == instruction_buffer->id()) { - if (read_only_indices_out != nullptr) { - *read_only_indices_out->mutable_element(index) = true; + }; + + VLOG(3) << copy->name() << " copies value " + << src->value->ToShortString(); + VLOG(3) << "Source buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(src); + + // A kCopy instruction copies an HLO value from a source buffer and + // defines an HLO value in a destination buffer. Most generally, the + // source and destination buffers may each hold more than one value at + // different points in the computation so we define the following: + // + // Values in source buffer: {s_0, ..., s_n} + // Values in destination buffer: {d_0, ..., d_m} + // + // A kCopy instruction between these buffers copies a value s_x in the + // source buffer and defines a value d_y in the destination buffer. The + // elision of a copy merges the source and destination buffers together, + // so the list of values for the source and destination buffers are + // merged. + // + // We handle two different cases for copy elision: + // + // (1) the kCopy defines the first value in the destination buffer (d_0). + // + // (2) the kCopy copies the last value in the source buffer (s_n). + // + // For the remaining case where the kCopy copies a not-last value from the + // source buffer to a not-first value of the destination buffer, the kCopy + // instruction cannot be removed. This case is generated, for example, if + // the kCopy copies a while body parameter of the loop state at one tuple + // index to a different tuple index in the while body root. Removal of the + // copy necessarily results in live range interference of values in the + // loop state at the two different tuple indices. + // + // We can only perform copy elision if the resulting merged values have + // totally ordered live ranges; otherwise the merged buffer would have + // live range interference. + if (IsHead(*dest)) { + // The copy copies an arbitrary value in the source buffer (call it s_x) + // and defines d_0, the first value in the destination buffer. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} + // + // Removing the copy eliminates d_0, and uses of d_0 become uses of + // s_x. In the above ordering, the live range of d_m must be ordered + // before the live range of s_{x+1} and the definition and all uses of + // s_x must be ordered before the definition of d_1. These conditions + // are checked below prior to elision. + // + // ** Technically it might be possible to have a non-interfering + // non-trivial interleaving of the values of the source and + // destination buffers in the resulting order. However, this case is + // slow and complicated to check and likely not worth it. So instead + // we simply check for the case where *all* values of the destination + // buffer (d_1 through d_m) are spliced into the point where the copy + // used to be. + VLOG(2) << copy->name() << " defines the first value in its buffer"; + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); + if (!is_live_range_before(*src, *next_dest)) { + return false; } - return; } - // We can't say anything about the ambiguity of 'other_instruction' at - // this point, so we need to check interference between the single - // buffer in the points-to set of 'instruction_' and all buffers in - // 'other_instruction_buffers'. - for (const LogicalBuffer* other_buffer : other_instruction_buffers) { - if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " instruction_buffer: " << instruction_buffer->ToString() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " because of interference with buffer: " - << other_buffer->ToString(); - RecordIndex(index); - break; + ValueNode* next_src = Next(*src); + + if (next_src != nullptr) { + // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. + ValueNode* last_dest = dest->prev; + DCHECK(IsTail(*last_dest)); + if (!is_live_range_before(*last_dest, *next_src)) { + return false; } } - }); - return Status::OK(); -} -// This is called when 'instruction_' is a while body root, and 'parameter' is -// the while body parameter. We record all users of all aliases of 'parameter' -// as control predecessors, so that when we add a copy of 'instruction_', we can -// mark the control dependencies. This is necessary because points-to and -// liveness analysis doesn't know about the aliasing between the while body root -// and param. Without these control dependencies, the copy might get scheduled -// to run at a point that interferes with users of the buffer. -Status InstructionCopier::RecordControlPredecessors( - const TuplePointsToAnalysis& points_to_analysis, - HloInstruction* parameter) { - return indices_to_copy_.ForEachElementWithStatus( - [this, &points_to_analysis, parameter](const ShapeIndex& index, - bool will_copy) { - if (will_copy) { - TF_ASSIGN_OR_RETURN( - const LogicalBuffer* buffer, - points_to_analysis.GetBufferDefinedAt(parameter, index)); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - user, points_to_analysis)) { - continue; - } - - if (user != instruction_) { - control_predecessors_.mutable_element(index)->push_back(user); - } - } + // Splice in destination buffer values list right after 'src'. + SpliceAfter(dest, src); + } else if (IsTail(*src)) { + // The copy copies the last value in the source buffer, s_n, and defines + // an arbitrary value in the destination buffer, d_y. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} + // + // Removing the copy eliminates d_y, and uses of d_y become uses of + // s_n. To enforce the above order, the live range of d_{y-1} must be + // before the live range of s_0, and the live range of s_n must be + // before the live range of d_{y+1}. + // + // ** See comment above in the code handling Case (1). + VLOG(2) << copy->name() << " copies the last value (" + << src->value->ToShortString() << ") in its buffer"; + + ValueNode* prev_dest = Prev(*dest); + // nullptr condition handled above in the first 'if' case. + DCHECK(prev_dest != nullptr); + ValueNode* first_src = src->next; + DCHECK(IsHead(*first_src)); + if (!is_live_range_before(*prev_dest, *first_src)) { + // Live range of value d_{y-1} is not before s_0. + return false; + } + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + if (!is_live_range_before(*src, *next_dest)) { + // Live range of value s_n is not before d_{y+1}. + return false; } } - return Status::OK(); - }); -} -// Recursively inserts copies of 'instruction' tuple element buffers at -// indices in 'indices_to_copy_', expanding tuples as needed. -HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction, - ShapeIndex* index) { - const int64 num_tuple_elements = - ShapeUtil::TupleElementCount(instruction->shape()); - std::vector elem_copies(num_tuple_elements); - for (int64 i = 0; i < num_tuple_elements; ++i) { - HloInstruction* elem; - if (instruction->opcode() == HloOpcode::kTuple) { - // If the instruction is already a Tuple instruction, we know that the - // element buffers are aliased, so we can just grab the operand directly. - elem = instruction->mutable_operand(i); - } else { - // Otherwise we need to add a GTE to unpack the element out of the tuple. - elem = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - } - index->push_back(i); - if (ShapeUtil::IsTuple(elem->shape())) { - elem_copies[i] = CopyTuple(elem, index); - } else if (!indices_to_copy_.element(*index)) { - elem_copies[i] = elem; - } else if (HloInstruction* copy_override = GetCopyOverride(*index)) { - elem_copies[i] = copy_override; - } else { - HloInstruction* elem_copy = elem->parent()->AddInstruction( - HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem)); - for (HloInstruction* control_predecessor : - control_predecessors_.element(*index)) { - VLOG(2) << "Adding control dependency from " - << control_predecessor->ToString() << " to " - << elem_copy->ToString(); - TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy)); + // Splice source buffer values list right after 'prev_dest'. + SpliceAfter(first_src, prev_dest); + } else { + VLOG(2) + << copy->name() + << " copies value in middle of source buffer to value in middle " + "of destination buffer"; + return false; } - elem_copies[i] = elem_copy; + + RemoveCopyValue(dest); + + XLA_VLOG_LINES(4, ToString()); + TF_DCHECK_OK(Verify()); + + return true; } - index->pop_back(); - } - return instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(elem_copies)); -} -// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'. -HloInstruction* InstructionCopier::Copy() { - ShapeIndex index; - HloInstruction* copy; - if (ShapeUtil::IsTuple(instruction_->shape())) { - copy = CopyTuple(instruction_, &index); - } else { - copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary( - instruction_->shape(), HloOpcode::kCopy, instruction_)); - } - for (HloInstruction* user : copy_users_) { - VLOG(2) << "Adding copy between instruction: " << instruction_->name() - << " and user: " << user->name(); - TF_CHECK_OK(instruction_->ReplaceUseWith(user, copy)); + // Delete the given ValueNode associated with a elided kCopy + // instruction. This should be called after splicing the value lists of the + // source and destination buffers together. + void RemoveCopyValue(ValueNode* copy_value_node) { + CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), + HloOpcode::kCopy); + ValueNode* operand_node = copy_value_node->prev; + CHECK(operand_node != copy_value_node); + + VLOG(2) << "Removing copy " << operand_node->value->ToShortString() + << " => " << copy_value_node->value->ToShortString(); + + // Splice out the copy value node. + operand_node->next = copy_value_node->next; + copy_value_node->next->prev = operand_node; + + // Patch up uses. Remove use of copy from operand_node uses. + auto it = + std::find_if(operand_node->uses.begin(), operand_node->uses.end(), + [copy_value_node](const HloUse* use) { + return use->instruction == + copy_value_node->value->defining_instruction(); + }); + CHECK(it != operand_node->uses.end()); + operand_node->uses.erase(it); + + // If the elided copy has any uses which are themselves kCopy instructions + // then patch up the copy info to reflect the that this kCopy instruction + // has a different operand (the operand of the elided copy). + for (const HloUse* copy_use : copy_value_node->uses) { + operand_node->uses.push_back(copy_use); + if (copy_use->instruction->opcode() == HloOpcode::kCopy) { + copy_map_.at(copy_use->instruction).src = operand_node; + } + } + + // Delete the copy info and the value node. + copy_map_.erase(copy_value_node->value->defining_instruction()); + delete copy_value_node; + } + + // Returns true if the live range of given value 'a' is before the live + // range of 'b'. + // + // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not + // updated as copies are removed. + bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { + if (a.uses.empty()) { + VLOG(2) << "Empty uses"; + return ordering_.IsDefinedBefore(*a.value, *b.value); + } + for (const HloUse* use : a.uses) { + VLOG(2) << "use: " << *use; + VLOG(2) << "is before:" << *b.value; + if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { + VLOG(2) << "Not before"; + return false; + } + } + return true; + } + + // Returns whether 'node' is the last node in its list. + bool IsTail(const ValueNode& node) const { + return ContainsKey(value_lists_, node.next); + } + + // Returns whether 'node' is the first node in its list. + bool IsHead(const ValueNode& node) const { + return ContainsKey(value_lists_, &node); + } + + // Returns the next node in the list after 'node'. If 'node' is the + // tail, then nullptr is returned. + ValueNode* Next(const ValueNode& node) const { + if (IsTail(node)) { + return nullptr; + } else { + return node.next; + } + } + + // Returns the previous node in the list before 'node'. If 'node' + // is the head, then nullptr is returned. + ValueNode* Prev(const ValueNode& node) const { + if (IsHead(node)) { + return nullptr; + } else { + return node.prev; + } + } + + // Splices the entire linked list with 'head' as its head right after the + // node 'insert_after' in another linked list. + void SpliceAfter(ValueNode* head, ValueNode* insert_after) { + DCHECK(IsHead(*head)); + value_lists_.erase(head); + + ValueNode* tail = head->prev; + tail->next = insert_after->next; + insert_after->next->prev = tail; + + insert_after->next = head; + head->prev = insert_after; + } + + string ValueListToString(const ValueNode* element) { + const ValueNode* head = element; + while (!IsHead(*head)) { + head = Prev(*head); + } + std::vector values; + for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { + values.push_back(p->value); + } + return StrCat("{", + Join(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); + } + + string ToString() const { + string out = StrCat("BufferValueTracker:\n"); + StrAppend(&out, " Def-use chains in each buffer:\n"); + for (const ValueNode* head : value_lists_) { + StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), + ":\n"); + const ValueNode* p = head; + do { + StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", + Join(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), + "\n"); + + p = p->next; + } while (p != head); + } + StrAppend(&out, " Potentially removable copies:\n"); + for (const auto& pair : copy_map_) { + const HloInstruction* copy = pair.first; + const CopyNodes& copy_info = pair.second; + + StrAppend(&out, " ", copy->name(), " : ", + copy_info.src->value->ToShortString(), " => ", + copy_info.dest->value->ToShortString(), "\n"); + } + return out; + } + + private: + const HloDataflowAnalysis& dataflow_; + const HloOrdering& ordering_; + + // The heads of all the value lists. Each value list represents the HLO + // values contained in a particular HLO buffer. The values in the list are + // in dependency order. + tensorflow::gtl::FlatSet value_lists_; + + // Copy removal requires fast access to the value list elements + // corresponding to the source and destination values of the kCopy + // instruction. This data structure holds pointers to these elements for + // each kCopy instruction in the graph. + struct CopyNodes { + // The source and destinations values of the kCopy instruction. + ValueNode* src = nullptr; + ValueNode* dest = nullptr; + }; + tensorflow::gtl::FlatMap copy_map_; + }; + + HloModule* module_; + const HloAliasAnalysis& alias_analysis_; + const HloOrdering& ordering_; + + // Object tracking the HLO values contained in each HLO buffer. + BufferValueTracker buffer_value_tracker_; +}; + +// Try to remove as many copies from the module as possible without introducing +// live range interference. Copy instructions (identified by their unique id) in +// the set copies_to_exclude are not considered for removal. +Status RemoveUnnecessaryCopies( + const HloOrdering& ordering, + const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + CopyRemover copy_remover(*alias_analysis, ordering, module); + XLA_VLOG_LINES(3, copy_remover.ToString()); + + tensorflow::gtl::FlatSet existing_copies; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + !ContainsKey(copies_to_exclude, instruction->unique_id())) { + TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + } + } } - return copy; + + return Status::OK(); } -// The 'read_only_indices' are initialized based on points-to analysis on the -// while body corresponding to 'while_hlo'. If the init buffer corresponding to -// a read-only index aliases with a constant, it cannot be considered read-only, -// and must be copied. This is necessary because BufferAssignment does not -// currently assign an allocation for constants (b/32248867). -// This function performs this fix-up of 'read_only_indices'. +// Add copies to address special constraints on the roots of computations not +// related to live range interference: // -// Returns a ShapeTree of copy_overrides, which implements an optimization to -// allow multiple while loops that share the same read-only constants to -// share a single copy. -StatusOr> RevertReadOnlyIndicesForConstants( - const HloInstruction* while_hlo, - const TuplePointsToAnalysis& points_to_analysis, - ShapeTree* read_only_indices, - FlatMap* shared_copies) { - const HloInstruction* init_hlo = while_hlo->operand(0); - const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); - - // Mapping from LogicalBuffer to index (used to detect non-distinct indices). - FlatSet buffer_set; - - ShapeTree copy_overrides(init_hlo->shape()); - points_to.ForEachElement([init_hlo, read_only_indices, shared_copies, - &buffer_set, ©_overrides]( - const ShapeIndex& index, - const PointsToSet::BufferList& buffers) { - // Look for read-only entry parameters. - if (!read_only_indices->element(index)) { - return; - } - for (const LogicalBuffer* buffer : buffers) { - HloInstruction* pointee = buffer->instruction(); - const bool is_constant = pointee->opcode() == HloOpcode::kConstant; - if (!is_constant) { - continue; - } +// (1) Entry computation root must be unambiguous and distinct. +// +// (2) Any computation called by a kCall instruction must have an +// unambiguous root. +// +// (3) Constants and parameters cannot be live out of the entry computation +// +Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + + // Identify which shape indices of which instructions need to be copied. Store + // these results in 'instructions_to_copy'. + std::unordered_map> instructions_to_copy; + auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction, + const ShapeIndex& index) { + auto it = instructions_to_copy.find(instruction); + if (it == instructions_to_copy.end()) { + auto it_added = instructions_to_copy.emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + *it->second.mutable_element(index) = true; + }; - // We have found an constant that is read-only in - // the while body. These buffers are managed by the caller, and cannot - // be aliased with HLO buffers. Revert this read-only index, - // to allow it to be copied. - *read_only_indices->mutable_element(index) = false; - - // Optimization to allow multiple while loops that share the same - // read-only entry constants to share a single copy. - // Only unambiguous and distinct array-shaped buffers are allowed, to - // reduce code complexity. The shape of the entry parameter must be - // identical to the shape of the init_hlo at this index, to ensure - // there were no intervening bitcast or GTE instructions, which are - // also hard to handle. - const Shape& pointee_shape = pointee->shape(); - const Shape& init_shape = - ShapeUtil::GetSubshape(init_hlo->shape(), index); - if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && - ShapeUtil::Equal(pointee_shape, init_shape) && - buffer_set.count(buffer) < 1) { - HloInstruction** copy = &(*shared_copies)[pointee]; - if (*copy == nullptr) { - *copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary( - pointee_shape, HloOpcode::kCopy, pointee)); + // Iterate through values of all constants and entry parameters. These values + // are special because they are held in read-only buffers. If any of these + // values share a buffer with other values (for example, the init value of a + // while is a constant) then copy the value at its definition and replace all + // its uses with the copy. + for (const HloValue* value : alias_analysis->dataflow_analysis().values()) { + if (ValueIsReadOnly(*value) && + alias_analysis->GetBufferContainingValue(*value).values().size() > 1) { + VLOG(2) << "Value " << value->ToShortString() + << " is read only, but its buffer contains more than one value. " + "Copying."; + add_index_to_copy(value->defining_instruction(), value->defining_index()); + } + } + + // Identify copies which must be added at root instructions + for (HloComputation* computation : module->computations()) { + const CallGraphNode& node = call_graph.GetNode(computation); + if (node.context() == CallContext::kParallel) { + continue; + } + TF_RET_CHECK(node.context() == CallContext::kSequential); + + const bool is_entry = computation == module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + + // Mark nondistinct/ambiguous indices. + tensorflow::gtl::FlatSet seen; + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector buffers_at_index = + alias_analysis->ComputeBuffersAt(root, index); + bool buffer_seen_before = false; + for (const HloBuffer* buffer : buffers_at_index) { + buffer_seen_before |= !seen.insert(buffer).second; + } + if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { + VLOG(2) << "Index " << index << " of root of computation " + << computation->name() << " (" << root->name() + << ") has ambiguous or non-distinct buffer. Copying."; + add_index_to_copy(root, index); + } + }); + + // For entry instructions, mark any parameter or constant values. + if (is_entry) { + for (const auto& pair : + alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (ValueIsReadOnly(*value)) { + VLOG(2) << "Root of entry computation (" << root->name() + << ") has constant or entry parameter value at index " + << index << ". Copying."; + add_index_to_copy(root, index); + } } - // Add the copy as an override. - *copy_overrides.mutable_element(index) = *copy; } + } + } - // Tracks whether this current buffer is distinct. - buffer_set.insert(buffer); + // Add copy instructions indicated in 'instructions_to_copy' to the module. + for (const auto& pair : instructions_to_copy) { + HloInstruction* instruction = pair.first; + const ShapeTree& indices_to_copy = pair.second; - // We've already reverted the read-only index and handled the - // single-copy optimization above, so there's nothing more to do. - break; + std::vector users = instruction->users(); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + instruction->parent()->DeepCopyInstruction( + instruction, &indices_to_copy)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); + } + if (instruction == instruction->parent()->root_instruction()) { + instruction->parent()->set_root_instruction(deep_copy); } - }); - return copy_overrides; + } + + return Status::OK(); +} + +Status VerifyNoLiveRangeInterference(HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + DependencyHloOrdering ordering(module); + TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); + return Status::OK(); } -} // anonymous namespace - -// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the -// base class, since the regular CopyInsertion logic above selectively copies -// tuple elements, while this method assumes all buffers need to be deep copied. -StatusOr CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) { - auto copy_it = inserted_copies_.find(hlo); - if (copy_it == inserted_copies_.end()) { - HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie(); - inserted_copies_.insert({hlo, copy}); - return copy; - } else { - return copy_it->second; +void MaybeDumpModule(const string& message, const HloModule& module) { + if (VLOG_IS_ON(3)) { + VLOG(3) << message; + XLA_VLOG_LINES(3, module.ToString()); + hlo_graph_dumper::MaybeDumpHloModule(module, message); } } +} // namespace + StatusOr CopyInsertion::Run(HloModule* module) { - bool changed = false; - VLOG(2) << "CopyInsertion for module " << module->name(); + // Copy insertion is performed in three steps: + // + // (1) Add copies conservatively to guarantee that there is no live-range + // interference. This is done simplistically and usually results in more + // copies than is strictly necessary. + // + // (2) Using a more fine-grained analysis, remove as many copies that were + // added in (1) as possible while ensuring no live-range interference. + // + // (3) Add copies to resolve issues not related to live range interference + // such as parameters and constants live out of the entry computation. + // + // We add copies then remove them (step (1) then (2)) rather than simply + // adding only the copies that are necessary because, in general, it is + // difficult to figure out the minimal set of copies to add once there is + // interference. On the other hand, it is easy to determine if removing a copy + // will introduce interference. + // + // The final copy insertion in (3) is done separately to simplify the + // implementation of copy removal in (2) which is the most complicated part of + // the pass. As is, copy removal only has to reason about live range + // interference. If all copies were added in step (1) then copy removal would + // also have to reason about things like constants and parameters live out of + // the computation. + MaybeDumpModule("before copy insertion", *module); - TF_ASSIGN_OR_RETURN( - std::unique_ptr liveness, - BufferLiveness::Run(module, MakeUnique(module))); - const auto& points_to_analysis = liveness->points_to_analysis(); - XLA_VLOG_LINES(2, points_to_analysis.ToString()); - XLA_VLOG_LINES(2, module->ToString()); - - // Gather all while body computations and while instructions. - FlatSet while_body_computations; - std::vector while_instructions; - for (auto* computation : module->computations()) { + std::unique_ptr call_graph = CallGraph::Build(module); + if (!call_graph->IsFlattened()) { + return FailedPrecondition( + "Call graph must be flattened before copy insertion."); + } + + // Gather Ids of existing kCopy instructions in the module. We avoid removing + // these copies (except via DCE in TupleSimplifier) because they may have been + // added for reasons not considered by copy insertion (eg, layout assignment). + // Instruction id is used instead of HloInstruction* because the pointer + // values may be recycled. + tensorflow::gtl::FlatSet existing_copies; + for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kWhile) { - while_body_computations.insert(instruction->while_body()); - while_instructions.push_back(instruction); + if (instruction->opcode() == HloOpcode::kCopy) { + existing_copies.insert(instruction->unique_id()); } } } - // Collect instruction buffer indices to copy in 'instructions_to_copy'. - std::vector instructions_to_copy; - - // Add copies of computation root instructions, if needed. - FlatMap> while_body_read_only_indices; - for (auto* computation : module->MakeNonfusionComputations()) { - VLOG(2) << "computation " << computation->name(); - InstructionCopier root_copier(computation->root_instruction(), - /*copy_users=*/{}); - if (while_body_computations.count(computation) > 0) { - // Record root indices to copy for while body sub-computations. We do not - // need to call RecordIndicesWhichPointToParamOrConstant for the while - // body root instruction here, because any necessary copies needed to - // avoid constants or parameters in the output are handled by while.init - // operand copy insertion below (which will share an allocation). - HloInstruction* while_body_param = computation->parameter_instruction(0); - ShapeTree read_only_indices(while_body_param->shape()); - TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( - *liveness, while_body_param, &read_only_indices)); - while_body_read_only_indices[computation] = read_only_indices; - - // Mark control predecessors, based on the body param, for any copies - // we'll be inserting. This ensures the copy doesn't run too early. - TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors( - points_to_analysis, while_body_param)); - } else { - // Record root indices to copy for general computations. - TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant( - points_to_analysis)); + TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module)); + + // Simplify the tuple structures introduced by the deep copies. This should be + // done before removing copies (RemoveUnnecessaryCopies) because tuple + // simplification changes dependencies in the graph which changes live range + // interference in the graph. Also run DCE to remove the dead Tuple/GTE + // instructions introduced by tuple simplification. + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + + TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + + MaybeDumpModule("after adding copies to resolve interference", *module); + + DependencyHloOrdering ordering(module); + TF_RETURN_IF_ERROR( + RemoveUnnecessaryCopies(ordering, existing_copies, module)); + + MaybeDumpModule("after removing unnecessary copies", *module); + + TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); + + MaybeDumpModule("after adding special-case copies", *module); + + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + + MaybeDumpModule("after copy insertion", *module); + + if (VLOG_IS_ON(1)) { + int64 num_total_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + num_total_copies++; + } + } } - instructions_to_copy.push_back(root_copier); + VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); + VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } - // Add copies of while 'init' operand instructions, if needed. 'shared_copies' - // is used to ensure that multiple while loops can share a single copy of the - // same entry parameter or constant, if all loops use it read-only. - // - // TODO(b/33301720) Remove redundant while instruction copies. - FlatMap shared_copies; - for (HloInstruction* while_hlo : while_instructions) { - // Fix read_only_indices to account for entry constants. Also - // initialize copy_overrides, which ensures a single copy for each read-only - // constant that is used in multiple while loops. - ShapeTree* read_only_indices = - &while_body_read_only_indices[while_hlo->while_body()]; - TF_ASSIGN_OR_RETURN( - const ShapeTree copy_overrides, - RevertReadOnlyIndicesForConstants(while_hlo, points_to_analysis, - read_only_indices, &shared_copies)); - // Create InstructionCopier for init operand of while instruction. - HloInstruction* init_hlo = while_hlo->mutable_operand(0); - InstructionCopier init_copier(init_hlo, {while_hlo}); - init_copier.SetReadOnlyIndices(*read_only_indices); - init_copier.SetCopyOverrides(copy_overrides); - // Record 'init' buffer indices which point-to a Constant or Parameter. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( - points_to_analysis)); - // Record indices necessary to colocate while and init operand buffers. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers( - *liveness, while_hlo, /*read_only_indices_out=*/nullptr)); - instructions_to_copy.push_back(init_copier); + return true; +} + +namespace { + +bool IsWhileBody(const HloComputation* computation, + const CallGraph& call_graph) { + const CallGraphNode& node = call_graph.GetNode(computation); + + if (node.context() == CallContext::kSequential && + !node.caller_callsites().empty()) { + // Callgraph should be flattened so sequential context computations can + // have at most one caller. + CHECK_EQ(node.caller_callsites().size(), 1); + const HloInstruction* calling_instruction = + node.caller_callsites()[0].instruction(); + if (calling_instruction->opcode() == HloOpcode::kWhile && + calling_instruction->while_body() == node.computation()) { + return true; + } } + return false; +} - for (InstructionCopier& to_copy : instructions_to_copy) { - if (to_copy.HasAllIndicesFalse()) { +} // namespace + +/* static */ StatusOr CopyInsertion::AddCopiesForBufferAssignment( + HloModule* module) { + std::unique_ptr call_graph = CallGraph::Build(module); + TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, + HloDataflowAnalysis::Run(module)); + + bool changed = false; + + // If a buffer live out of a computation is a constant, a parameter, or not + // defined in the computation, then copy it to account for the limited + // computation-scoped analysis in buffer assignment. An exception to this rule + // is the while body which is handled properly without copies. + for (HloComputation* computation : module->computations()) { + if (computation == module->entry_computation() || + IsWhileBody(computation, *call_graph)) { continue; } - changed = true; - // Copy instruction at recorded buffer indices. - HloComputation* computation = to_copy.instruction()->parent(); - HloInstruction* copy = to_copy.Copy(); - if (to_copy.instruction() == computation->root_instruction()) { - computation->set_root_instruction(copy); + HloInstruction* root = computation->root_instruction(); + ShapeTree indices_to_copy(root->shape(), /*init_value=*/false); + bool copy_root = false; + for (const auto& pair : dataflow->GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + HloInstruction* def = value->defining_instruction(); + if (def->parent() != computation || + def->opcode() == HloOpcode::kConstant || + def->opcode() == HloOpcode::kParameter) { + *indices_to_copy.mutable_element(index) = true; + copy_root = true; + } + } + } + if (copy_root) { + TF_ASSIGN_OR_RETURN( + HloInstruction * root_copy, + computation->DeepCopyInstruction(root, &indices_to_copy)); + computation->set_root_instruction(root_copy); + changed = true; } } - VLOG(3) << "After copy insertion for module " << module->name(); - XLA_VLOG_LINES(3, module->ToString()); + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed, + tuple_simplifier.Run(module)); + TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module)); - return changed; + return changed || tuple_simplifier_changed || dce_changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 28bb62e40c7674960dbb1bb63dc8967b06956028..65e3d31e347e2cb249a072e7d06ca10c55401748 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -25,12 +25,25 @@ limitations under the License. namespace xla { -// HLO pass which inserts a copy of the root instruction (creating a new root) -// if the root is or points-to any constant or parameter instruction. -// If the root instruction is a Tuple, only tuple elements which point to -// constant or parameter instructions will be copied. -// Copy insertion is necessary because constant and parameter arrays have -// different lifetimes than computation results. +// Copy insertion is a legalization HLO pass which inserts copies (kCopy +// instructions) to eliminate several kinds of problems in the HLO module. +// +// (1) Entry parameter or a constant live out of the entry computation. Entry +// computation arguments and constants have different lifetimes than the +// computation result and cannot share the same allocation. Parameters and +// constants live out of non-entry computations do not need copies. +// +// (2) Different values which are simultaneously live and which must be held +// in the same buffer. This can occur in while bodies. Specifically, the +// while loop state (the arguments to the while instruction) is updated +// in-place and the update may clobber the value from the previous +// iteration before the previous value is dead. Computations called from +// kCall instructions do not need such copies because kCall has no update +// in-place semantics. +// +// (3) The buffer set of the root instruction of the entry computation must be +// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and +// InstructionAliasSet::IsDistinct return true. class CopyInsertion : public HloPassInterface { public: tensorflow::StringPiece name() const override { return "copy-insertion"; } @@ -39,14 +52,16 @@ class CopyInsertion : public HloPassInterface { // (copies were inserted). StatusOr Run(HloModule* module) override; - protected: - // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making - // duplicate copies. - StatusOr FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted during the copy insertion pass. The - // key is the copied instruction and the value is the copy. - tensorflow::gtl::FlatMap inserted_copies_; + // The CPU and GPU backend need additional copies added due to deficiencies in + // buffer assignment. Specifically, copies are needed for constants live-out + // of computations, and for values which are live-in and live-out of the same + // computation. These copies are needed because buffer-assignment uses a + // computation-scoped analyis (TuplePointsToAnalysis) and has limited + // visibility across computation boundaries. This method adds these necessary + // copies. Returns whether the module was modified. + // + // TODO(b/62548313): Remove this when buffer assignment is module-scoped. + static StatusOr AddCopiesForBufferAssignment(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index a2eacc5c7dae2424e01fdd49d82546b5488d4312..8388574716ad1b78eb8868a8cd732005050b3310 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -17,18 +17,19 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace op = xla::testing::opcode_matchers; @@ -37,35 +38,53 @@ namespace { using ::testing::UnorderedElementsAre; +int64 CountCopies(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + count++; + } + } + return count; +} + +int64 CountCopies(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountCopies(*computation); + } + return count; +} + +int64 CountControlEdges(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + count += instruction->control_successors().size(); + } + return count; +} + +int64 CountControlEdges(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountControlEdges(*computation); + } + return count; +} + class CopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CopyInsertion copy_insertion; - EXPECT_IS_OK(copy_insertion.Run(module).status()); - - // Verify the points to set of the root of the computation after copy - // insertion contains no constants or parameters, and is distinct and - // non-ambiguous. - auto points_to_analysis = - TuplePointsToAnalysis::Run(module).ConsumeValueOrDie(); - const auto& points_to = points_to_analysis->GetPointsToSet( - module->entry_computation()->root_instruction()); - EXPECT_TRUE(points_to.IsDistinct()); - EXPECT_TRUE(!points_to.IsAmbiguous()); - - auto maybe_live_out_buffers = - points_to_analysis - ->GetPointsToSet(module->entry_computation()->root_instruction()) - .CreateFlattenedSet(); - - for (const LogicalBuffer* buffer : maybe_live_out_buffers) { - EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant); - EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); - } + ASSERT_IS_OK(copy_insertion.Run(module).status()); } + + const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); }; TEST_F(CopyInsertionTest, SingleParameter) { + // Computation is a single parameter passed into a tuple. The parameter should + // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -77,14 +96,15 @@ TEST_F(CopyInsertionTest, SingleParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(x))); } TEST_F(CopyInsertionTest, SingleConstant) { + // Computation is a single constant passed into a tuple. The parameter should + // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -96,11 +116,42 @@ TEST_F(CopyInsertionTest, SingleConstant) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(constant))); +} + +TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { + // Verify that an kCopy instructions which exist in the pass before + // copy-insertion remain in the graph after copy-insertion. + auto module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); + HloInstruction* add_copy = builder.AddInstruction( + HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); + + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(CountCopies(*module), 3); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 3); + + EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -127,12 +178,12 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)), - op::Copy(old_root->operand(1)), old_root->operand(2))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y))); } TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { @@ -165,6 +216,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Copy(op::GetTupleElement(old_root)), @@ -187,6 +239,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -208,6 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -227,11 +281,11 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(bitcast))); } TEST_F(CopyInsertionTest, NestedTupleParameter) { @@ -257,6 +311,8 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 3); + HloInstruction* new_root = module->entry_computation()->root_instruction(); EXPECT_NE(old_root, new_root); @@ -283,7 +339,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - // The return value of the computation is the zero-th elemnt of the nested + // The return value of the computation is the zero-th element of the nested // tuple. This element is itself a tuple. auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); @@ -293,12 +349,13 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { EXPECT_EQ(gte, module->entry_computation()->root_instruction()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(op::GetTupleElement(old_root)), - op::Copy(op::GetTupleElement(old_root)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))), + op::Copy(op::GetTupleElement(op::GetTupleElement(param))))); } TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { @@ -331,6 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -346,12 +404,10 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // The parameter 'nested' specifies the loop state shape from which to // read the induction variable. std::unique_ptr BuildConditionComputation( - bool nested = false) { + const Shape& loop_state_shape) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(10))); - const Shape& loop_state_shape = - nested ? nested_loop_state_shape_ : loop_state_shape_; auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); auto induction_variable = @@ -582,7 +638,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto loop_state_init = builder.AddInstruction( HloInstruction::CreateTuple({induction_var_init, inner_init})); auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( - loop_state_shape_, condition, body, loop_state_init)); + loop_state_init->shape(), condition, body, loop_state_init)); module_->AddEntryComputation(builder.Build()); return while_hlo; } @@ -658,11 +714,28 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. - builder.AddInstruction(HloInstruction::CreateBinary( + auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); - return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, - &builder); + auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_, + data_init, &builder); + + // Add an additional binary operation operating on the while and the + // interfering add so that neither operation is dead. + auto gte = xla_while->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1)); + auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kSubtract, add, gte)); + auto gte0 = xla_while->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0)); + auto tuple = xla_while->parent()->AddInstruction( + HloInstruction::CreateTuple({gte0, sub})); + + xla_while->parent()->set_root_instruction(tuple); + + return xla_while; } HloInstruction* BuildWhileInstructionWithCustomInit( @@ -672,8 +745,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0))); - auto condition = - module_->AddEmbeddedComputation(BuildConditionComputation(nested)); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); auto body = module_->AddEmbeddedComputation( BuildIndependentBodyComputation(nested)); auto loop_state_init = builder->AddInstruction( @@ -706,23 +779,21 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // CopyInsertion pass should not generate any copies. // TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildIndependentBodyComputation()); auto while_hlo = BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - // No copies should be inserted so root should not be updated. - EXPECT_EQ(old_root, new_root); + // Body should have no copies as the adds can be done inplace. + EXPECT_EQ(CountCopies(*body), 0); + EXPECT_EQ(CountControlEdges(*module_), 0); - // Both init indices need copies. - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + // Both init indices need copies as they are constants. + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while body computation with dependent tuple elements: @@ -737,20 +808,33 @@ TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { // Tuple(Copy(out0), out1) // TEST_F(WhileCopyInsertionTest, DependentTupleElements) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation()); auto while_hlo = BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - EXPECT_THAT(new_root, - op::Tuple(op::Copy(old_root->operand(0)), old_root->operand(1))); - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_EQ(CountCopies(*body), 1); + EXPECT_EQ(CountControlEdges(*body), 0); + + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast()))); + + auto add = body->root_instruction()->operand(0); + auto bcast = body->root_instruction()->operand(1)->operand(1); + ASSERT_EQ(add->opcode(), HloOpcode::kAdd); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(op::Copy(), op::Constant()), + op::Add(op::GetTupleElement(), op::Broadcast(op::Copy())))); + + // Both init indices need copies as they are constants. + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while body computation with read-only tuple element 0: @@ -768,33 +852,26 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { // // CopyInsertion pass should not generate any copies for the while body. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); - auto while_hlo = BuildWhileInstruction(condition, body); + BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - - // No copies should be inserted in the body, so root should not be updated. - EXPECT_EQ(old_root, new_root); - // Both indices need copies, even though Index 0 is read-only, since both are - // constants, which must be copied. - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + // No copies or control edges should be inserted. The body is legal as is. + EXPECT_EQ(CountCopies(*body), 0); + EXPECT_EQ(CountControlEdges(*body), 0); } // Same as above, but with two while loops, sharing entry parameters. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body1 = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); auto body2 = module_->AddEmbeddedComputation( @@ -812,30 +889,46 @@ TEST_F(WhileCopyInsertionTest, loop_state_shape_, condition1, body1, loop_init)); auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + + // Add a couple elements from each of the while so both whiles are live. + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); + + auto entry = module_->AddEntryComputation(builder.Build()); InsertCopies(module_.get()); - // Both while loops alias iter_param, since index 0 is read-only in the body. - EXPECT_EQ(while_hlo1->operand(0)->operand(0), - while_hlo2->operand(0)->operand(0)); - EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_param); + // Neither body should have any copies or control edges in them. + EXPECT_EQ(CountCopies(*body1), 0); + EXPECT_EQ(CountCopies(*body2), 0); + EXPECT_EQ(CountControlEdges(*body1), 0); + EXPECT_EQ(CountControlEdges(*body2), 0); - // Each while loop gets its own copy of data_param, since index 1 is not - // read-only in the body. + // Only two copies should be necessary. Each of the whiles should have + // a copy of tuple element 1 (init value is a parameter, and the element is + // not non-read-only) so each of the while bodies gets its own buffer to write + // element 1 into. + EXPECT_EQ(CountCopies(*entry), 2); + + EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); + EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); + + // The two copies of element 1 should be different. EXPECT_NE(while_hlo1->operand(0)->operand(1), while_hlo2->operand(0)->operand(1)); - EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_param)); - EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_param)); } // Same as above, but with two while loops, sharing non-parameters. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body1 = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); auto body2 = module_->AddEmbeddedComputation( @@ -858,21 +951,28 @@ TEST_F(WhileCopyInsertionTest, loop_state_shape_, condition1, body1, loop_init)); auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + + // Add a couple elements from each of the while so both whiles are not dead. + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); + auto entry = module_->AddEntryComputation(builder.Build()); InsertCopies(module_.get()); - // No copies of iter_value are necessary, since index 0 is read-only in both - // while bodies. - EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_value); - EXPECT_EQ(while_hlo2->operand(0)->operand(0), iter_value); + // Ideally only one copy should be necessary. One of the whiles should + // have a copy of tuple element 1 (the non-read-only element) so each of the + // while bodies gets its own buffer to write element 1 into. However, the + // analysis isn't perfect and adds an additional copy of element 0. + EXPECT_EQ(CountCopies(*entry), 2); - // Each while loop gets its own copy of data_value, since index 1 is not - // read-only in the body. - EXPECT_NE(while_hlo1->operand(0)->operand(1), - while_hlo2->operand(0)->operand(1)); - EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_value)); - EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_value)); + EXPECT_THAT(while_hlo1->operand(0), + op::Tuple(op::Exp(), op::Copy(op::Exp()))); + EXPECT_THAT(while_hlo2->operand(0), + op::Tuple(op::Exp(), op::Copy(op::Exp()))); } // Tests while body computation with nested tuple elements: @@ -905,18 +1005,34 @@ TEST_F(WhileCopyInsertionTest, // Tuple // new root // TEST_F(WhileCopyInsertionTest, NestedTupleElements) { - auto condition = - module_->AddEmbeddedComputation(BuildConditionComputation(true)); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(nested_loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation()); BuildWhileInstruction(condition, body, true); - HloInstruction* old_root = body->root_instruction(); + // HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - EXPECT_THAT(body->root_instruction(), - op::Tuple(old_root->operand(0), - op::Tuple(old_root->operand(1)->operand(0), - op::Copy(old_root->operand(1)->operand(1))))); + // The only copy necessary is for the kReverse as it cannot be done + // in-place (instruction can share buffer with operand). The other elements of + // the loop state are kAdd instructions which can be done in-place. + EXPECT_EQ(CountCopies(*body), 1); + + // Each element of the init needs a copy as all are constants. + EXPECT_EQ(CountCopies(*module_), 4); + + // Either the kReverse itself must be copied or the operand of the kReverse + // must be copied. + if (body->root_instruction()->operand(1)->operand(1)->opcode() == + HloOpcode::kCopy) { + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse())))); + } else { + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy())))); + } } // Tests while init instruction which points-to a constant. @@ -927,11 +1043,13 @@ TEST_F(WhileCopyInsertionTest, NestedTupleElements) { // TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); + EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while init instruction which points-to a parameter. @@ -942,11 +1060,13 @@ TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { // TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); + EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter()))); } // Tests while init instruction which has an ambiguous points-to set. @@ -975,15 +1095,34 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { // TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); - auto old_init = while_hlo->operand(0); - InsertCopies(module_.get()); - EXPECT_THAT( - while_hlo->operand(0), - op::Tuple( - op::Copy(old_init->operand(0)), - op::Tuple(op::Copy(op::GetTupleElement(old_init->operand(1))), - op::Copy(op::GetTupleElement(old_init->operand(1)))))); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*module_), 4); + // The entry computation requires three copies to resolve the ambiguity of two + // init elements and the constant passed in as one of the init elements. + EXPECT_EQ(CountCopies(*module_->entry_computation()), 3); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Copy(op::GetTupleElement()), + op::Copy(op::GetTupleElement())))); + + // The body requires one copy because the buffer set is not distinct: the + // result of one of the adds is written into two elements of the output of the + // loop body. Either element might be copied. + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); + if (while_hlo->while_body() + ->root_instruction() + ->operand(1) + ->operand(0) + ->opcode() == HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); + } else { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); + } } // Tests while init instruction which has a non-distinct points-to set. @@ -1011,13 +1150,43 @@ TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { // TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); - EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(old_init->operand(0)), - op::Tuple(op::Copy(old_init->operand(1)->operand(0)), - op::Copy(old_init->operand(1)->operand(0))))); + // The entry computation requires two copies to resolve the non-disinctness of + // two init elements and the constant passed in as one of the init + // elements. Either element can be copied for the distinctness issue. + EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); + if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() == + HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Copy(op::Broadcast()), op::Broadcast()))); + } else { + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Broadcast(), op::Copy(op::Broadcast())))); + } + + // The body requires one copy because the buffer set is not distinct: the + // result of one of the adds is written into two elements of the output of the + // loop body. Either element might be copied. + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); + if (while_hlo->while_body() + ->root_instruction() + ->operand(1) + ->operand(0) + ->opcode() == HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); + } else { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); + } } // Tests while init instruction buffer which interferes with while result @@ -1031,11 +1200,13 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { // TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*module_), 2); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast()))); } // Tests while init instruction buffer which has a non-distinct points-to set: @@ -1044,18 +1215,21 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { // Parameter(F32, {8}))) // // where the second and third parameters are identical *and* the tuple shared -// by another while instruction.. +// by another while instruction. // // Verifies that the resulting point-to set is distinct in the resulting Tuple // (non-identical Copys). In other words, verifies that copy sharing does not // insert identical copies to the resulting tuple. TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); // Loop body that outputs tuple comprises two elements dependent on the init // tuple. + const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); + + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); auto body1 = module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); auto body2 = @@ -1072,8 +1246,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto loop_init = builder.AddInstruction( HloInstruction::CreateTuple({iter_param, data_param, data_param})); - const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( - {induction_variable_shape_, data_shape_, data_shape_}); // Two while loops shares the same loop init tuple. auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( @@ -1081,43 +1253,478 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + // Add add instruction so neither while is dead. + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); - auto points_to_analysis = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + module_->AddEntryComputation(builder.Build()); - // Asserts that the init tuples before copy insertion is non-distinct. - ASSERT_FALSE( - points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct()); - ASSERT_FALSE( - points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct()); + InsertCopies(module_.get()); - auto old_init1 = while_hlo1->operand(0); - auto old_init2 = while_hlo2->operand(0); + // None of the bodies should have copies or control flow edges. + EXPECT_EQ(CountCopies(*body1), 0); + EXPECT_EQ(CountCopies(*body2), 0); - InsertCopies(module_.get()); + // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally + // these should not need to be copied before either while. However, copy + // insertion is not able to reason about the transparency of elements through + // while bodies in all circumstances so extra copies are added (b/xxx). + EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); EXPECT_THAT(while_hlo1->operand(0), - op::Tuple(op::Copy(old_init1->operand(0)), - op::Copy(old_init1->operand(1)), - op::Copy(old_init1->operand(2)))); - + op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); EXPECT_THAT(while_hlo2->operand(0), - op::Tuple(op::Copy(old_init2->operand(0)), - op::Copy(old_init2->operand(1)), - op::Copy(old_init2->operand(2)))); - - // Verifies the init tuples after copy insertion is distinct. - points_to_analysis = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - const auto& points_to1 = - points_to_analysis->GetPointsToSet(while_hlo1->operand(0)); - EXPECT_TRUE(points_to1.IsDistinct()); - - const auto& points_to2 = - points_to_analysis->GetPointsToSet(while_hlo2->operand(0)); - EXPECT_TRUE(points_to2.IsDistinct()); + op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); } +TEST_F(CopyInsertionTest, SwizzlingWhile) { + // Test a while instruction with a body which permutes its tuple parameter + // elements. + auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body simply interchanges the two tuple elements in the loop state. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_1, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 6); + + // The loop state elements should be copied at the parameter and at the root + // with a control edge in between (see DeepCopyAndAddControlEdges). This is + // technically one more copy than is strictly necessary, but in order to have + // only three copies the copies of different loop state elements must be + // ordered with a control edge. + EXPECT_EQ(CountCopies(*body), 4); + EXPECT_EQ(CountControlEdges(*body), 2); + + EXPECT_THAT(body->root_instruction(), + op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy()))); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { + // Test a while instruction with a body which permutes its tuple parameter + // elements and applies one operation to one of the elements. The addition of + // the operation (instruction) on the element makes the live range of the + // respective input and output elements different than if the instruction were + // not there (as in the SwizzlingWhile test above). + auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body interchanges the two tuple elements in the loop state and negates one + // of them. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, body_element_1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({negate, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 6); + + // The loop state elements should be copied at the parameter and at the root + // with a control edge in between (see DeepCopyAndAddControlEdges). + EXPECT_EQ(CountCopies(*body), 4); + EXPECT_EQ(CountControlEdges(*body), 2); + + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy()))); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { + // Test a while instruction with a body which permutes it's tuple parameter + // elements similar to SwizzlinWhile above. However, in this test the input to + // the while body is a single constant (both loop state elements are the same + // constant). This means no copies are necessary because both loop state + // elements are the same so interchanging them is a no-op. + auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body simply interchanges the two tuple elements in the loop state. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_1, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); + builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 2); + EXPECT_EQ(CountCopies(*body), 0); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SequentialWhiles) { + // Construct a computation with a series of sequential while instructions + // containing four loop state elements: + // + // element 0 is passed to each while directly from an entry parameter. + // + // element 1 is passed transparently in series through all the while bodies. + // + // element 2 is negated in each while body. (in-place possible) + // + // element 3 is reversed in each while body. (in-place not possible) + // + const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); + const Shape loop_state_shape = ShapeUtil::MakeTupleShape( + {element_shape, element_shape, element_shape, element_shape}); + + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, element_shape, "param_0")); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, element_shape, "param_1")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, element_shape, "param_2")); + auto param_3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, element_shape, "param_3")); + + // The number of sequential kWhile instructions. + const int kNumWhiles = 3; + + HloInstruction* prev_element_1 = param_1; + HloInstruction* prev_element_2 = param_2; + HloInstruction* prev_element_3 = param_3; + + // Vector containing all of the while instructions. + std::vector whiles; + for (int i = 0; i < kNumWhiles; ++i) { + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 1)); + auto body_element_2 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 2)); + auto body_element_3 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 3)); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + element_shape, HloOpcode::kNegate, body_element_2)); + auto reverse = body_builder.AddInstruction( + HloInstruction::CreateReverse(element_shape, body_element_3, {0})); + body_builder.AddInstruction(HloInstruction::CreateTuple( + {body_element_0, body_element_1, negate, reverse})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto while_init = builder.AddInstruction(HloInstruction::CreateTuple( + {param_0, prev_element_1, prev_element_2, prev_element_3})); + + auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, condition, body, while_init)); + whiles.push_back(xla_while); + if (i != kNumWhiles - 1) { + prev_element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1)); + prev_element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2)); + prev_element_3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3)); + } + } + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + // Each while body has one copy. And each loop state element is copied once in + // the entry computation. + EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles); + + // Each while body should have exactly one copy for element three which is an + // op (kReverse) which cannot be done in place. + for (const HloInstruction* xla_while : whiles) { + EXPECT_EQ(CountCopies(*xla_while->while_body()), 1); + } + + EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(), + op::Copy(), op::Copy())); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(), + op::GetTupleElement())); +} + +TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { + // Test a while body and condition which are each simply a constant (root of + // computation is a constant). The body constant should be copied. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + + auto body_builder = HloComputation::Builder("body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 2); + + EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); + EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); + EXPECT_THAT(condition->root_instruction(), op::Constant()); +} + +std::unique_ptr MakeTrivialCondition(const Shape& shape) { + auto builder = HloComputation::Builder("trivial_condition"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "loop_state")); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNot, constant)); + return builder.Build(); +} + +std::unique_ptr MakeBenchmarkWhileBody() { + auto builder = HloComputation::Builder("benchmark_loop_body"); + const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + HloInstruction* element_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 0)); + HloInstruction* element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 1)); + HloInstruction* element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 2)); + + HloInstruction* rev_1 = builder.AddInstruction( + HloInstruction::CreateReverse(element_shape, element_1, {0})); + HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary( + element_shape, HloOpcode::kAdd, element_1, element_2)); + + builder.AddInstruction( + HloInstruction::CreateTuple({element_0, rev_1, add_1_2})); + return builder.Build(); +} + +void BM_SequentialWhiles(int num_iters, int num_whiles) { + // This benchmark constructs a chain of sequential while instructions. + tensorflow::testing::StopTiming(); + for (int i = 0; i < num_iters; ++i) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), + config); + + auto builder = HloComputation::Builder("BM_SequentialWhiles"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {42}), "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {42}), "y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {42}), "z")); + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); + + HloInstruction* prev_loop_state = init; + for (int w = 0; w < num_whiles; ++w) { + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); + prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile( + init->shape(), condition, body, prev_loop_state)); + } + module.AddEntryComputation(builder.Build()); + + CopyInsertion copy_insertion; + + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + + // The entry computation should have three copies, and each body has one. + ASSERT_EQ(CountCopies(module), 3 + num_whiles); + } +} + +void BM_ParallelWhiles(int num_iters, int num_whiles) { + // This benchmark constructs a fan-out of parallel while instructions. + tensorflow::testing::StopTiming(); + for (int i = 0; i < num_iters; ++i) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), + config); + + auto builder = HloComputation::Builder("BM_ParallelWhiles"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {42}), "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {42}), "y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {42}), "z")); + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); + + HloInstruction* sum = nullptr; + for (int w = 0; w < num_whiles; ++w) { + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); + + HloInstruction* xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(init->shape(), condition, body, init)); + + if (sum == nullptr) { + sum = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); + } else { + HloInstruction* element_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); + sum = builder.AddInstruction(HloInstruction::CreateBinary( + x->shape(), HloOpcode::kAdd, sum, element_0)); + } + } + module.AddEntryComputation(builder.Build()); + + CopyInsertion copy_insertion; + + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + + // Each body receives of copy of two of the parameters (the corresponding + // elements in the body are modifed), and there is one copy in each body. + ASSERT_EQ(CountCopies(module), 3 * num_whiles); + } +} + +BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); +BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 78216f2ffb9c58d7f4b7ca31cb740d547ea1d470..ed142bd077fc20f5e2e563132df95bff10a37f0f 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -79,15 +79,16 @@ cc_library( deps = [ ":compiler_functor", ":conv_canonicalization", + ":cpu_copy_insertion", ":cpu_executable", ":cpu_instruction_fusion", + ":cpu_layout_assignment", ":cpu_options", ":cpu_parallelization_preparation", ":disassembler", ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":parallel_cpu_executable", ":parallel_task_assignment", ":simple_orc_jit", @@ -99,17 +100,18 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", - "//tensorflow/compiler/xla/service:batchnorm_rewriter", + "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", - "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", @@ -250,6 +252,8 @@ cc_library( ":dot_op_emitter", ":external_constant_pool", ":ir_emission_utils", + ":ir_function", + ":parallel_loop_emitter", ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", @@ -273,12 +277,48 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:ops", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@llvm//:analysis", + "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", "@llvm//:target", ], ) +cc_library( + name = "ir_function", + srcs = ["ir_function.cc"], + hdrs = ["ir_function.h"], + deps = [ + ":ir_emission_utils", + ":shape_partition", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "parallel_loop_emitter", + srcs = ["parallel_loop_emitter.cc"], + hdrs = ["parallel_loop_emitter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], @@ -614,13 +654,14 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", + "@llvm//:core", ], ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "cpu_layout_assignment", + srcs = ["cpu_layout_assignment.cc"], + hdrs = ["cpu_layout_assignment.h"], deps = [ ":dot_op_emitter", ":ir_emission_utils", @@ -632,11 +673,11 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", + name = "cpu_layout_assignment_test", size = "small", - srcs = ["layout_assignment_test.cc"], + srcs = ["cpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":cpu_layout_assignment", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -750,6 +791,38 @@ cc_library( ], ) +cc_library( + name = "cpu_copy_insertion", + srcs = ["cpu_copy_insertion.cc"], + hdrs = ["cpu_copy_insertion.h"], + deps = [ + "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "cpu_copy_insertion_test", + srcs = ["cpu_copy_insertion_test.cc"], + deps = [ + ":cpu_copy_insertion", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 44cd2171afdc6eecc22f3f920276a4d95f930573..2136aeb3877685373efaf5bf702a42b39a63f082 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -41,19 +41,17 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension(); auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension(); - int num_spatial_dims = dnums.spatial_dimensions_size(); - int num_dims = num_spatial_dims + 2; + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + const int64 num_dims = num_spatial_dims + 2; // A canonical convolution's dimension numbers need to satisfy the // following conditions (see cs/PotentiallyImplementedAsEigenConvolution). // - // - the input is in NHWC or NWHC order. - // - the kernel is in HWIO or WHIO order. - // - the spatial dimensions are in the same relative order in the input, - // kernel and output. + // - the input is in NHWC order. + // - the kernel is in HWIO order. // // For simplicity, as a first step, we reshape the input and filter to - // NHWC and HWIO order, respectively. This may lose precision but not + // NHWC and HWIO order, respectively. This may lose precision but won't // break the soundness. HloInstruction* input = hlo->mutable_operand(0); @@ -61,10 +59,10 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_input_dims(num_dims); new_input_dim_order[0] = input_batch_dim; new_input_dims[0] = input->shape().dimensions(input_batch_dim); - for (int i = 0; i < num_spatial_dims; ++i) { - new_input_dim_order[i + 1] = dnums.spatial_dimensions(i); + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_input_dim_order[i + 1] = dnums.input_spatial_dimensions(i); new_input_dims[i + 1] = - input->shape().dimensions(dnums.spatial_dimensions(i)); + input->shape().dimensions(dnums.input_spatial_dimensions(i)); } new_input_dim_order[num_dims - 1] = input_feature_dim; new_input_dims[num_dims - 1] = @@ -80,7 +78,7 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_kernel_dim_order(num_dims); std::vector new_kernel_dims(num_dims); - for (int i = 0; i < num_spatial_dims; ++i) { + for (int64 i = 0; i < num_spatial_dims; ++i) { new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i); new_kernel_dims[i] = kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i)); @@ -98,14 +96,18 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { HloInstruction::CreateTranspose(new_kernel_shape, kernel, new_kernel_dim_order)); + std::vector new_output_dim_order(num_dims); std::vector new_conv_dims(num_dims); auto output_batch_dim = dnums.output_batch_dimension(); auto output_feature_dim = dnums.output_feature_dimension(); + new_output_dim_order[0] = output_batch_dim; new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim); - for (int i = 0; i < num_spatial_dims; ++i) { + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_output_dim_order[i + 1] = dnums.output_spatial_dimensions(i); new_conv_dims[i + 1] = - hlo->shape().dimensions(dnums.spatial_dimensions(i)); + hlo->shape().dimensions(dnums.output_spatial_dimensions(i)); } + new_output_dim_order[num_dims - 1] = output_feature_dim; new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim); Shape new_conv_shape = ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims); @@ -113,9 +115,10 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { ConvolutionDimensionNumbers new_dnums; new_dnums.set_input_batch_dimension(0); new_dnums.set_output_batch_dimension(0); - for (int i = 0; i < num_spatial_dims; ++i) { - new_dnums.add_spatial_dimensions(i + 1); + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_dnums.add_input_spatial_dimensions(i + 1); new_dnums.add_kernel_spatial_dimensions(i); + new_dnums.add_output_spatial_dimensions(i + 1); } new_dnums.set_input_feature_dimension(num_dims - 1); new_dnums.set_output_feature_dimension(num_dims - 1); @@ -129,14 +132,11 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, hlo->window(), new_dnums)); - // kConvolution inherits the dimension mapping of its input, so we need to - // reshape the output back to the shape of the original convolution. This - // is done by apply the inverse permutation of the collapsing order of the - // input reshape. + // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( hlo, HloInstruction::CreateTranspose( hlo->shape(), new_conv, - InversePermutation(new_input_dim_order)))); + InversePermutation(new_output_dim_order)))); changed = true; } } diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index d593ba26b655d00a0f0f0b9a94c9e62fa1835080..968f53d5c706651d2a470a853e0e9b601c0ed2df 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -69,8 +69,10 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(1); dnums.set_output_batch_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_input_feature_dimension(0); dnums.set_output_feature_dimension(0); dnums.add_kernel_spatial_dimensions(2); @@ -125,8 +127,10 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index f5b95d3657cb91623aa043f7544760c11fc87408..6dc30bfe2cd036f7e83b054d25361072ac5077e9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -42,32 +42,34 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -197,28 +199,35 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: static StatusOr> - GetCandidatesForComputation(HloComputation* computation) { + GetCandidatesForComputation( + HloComputation* computation, + const std::unordered_map& + assigned_indices) { std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( - &hlo_to_profile_idx); + &hlo_to_profile_idx, assigned_indices); TF_RETURN_IF_ERROR( computation->Accept(&profile_candidates_for_computation)); return hlo_to_profile_idx; } private: - explicit CollectProfileCandidates( - std::unordered_map* hlo_to_profile_idx) - : hlo_to_profile_idx_(hlo_to_profile_idx) {} + CollectProfileCandidates( + std::unordered_map* hlo_to_profile_idx, + const std::unordered_map& assigned_indices) + : hlo_to_profile_idx_(hlo_to_profile_idx), + assigned_indices_(assigned_indices) {} Status DefaultAction(HloInstruction* hlo_instruction) override { - hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()}); + hlo_to_profile_idx_->insert( + {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)}); return Status::OK(); } Status HandleCall(HloInstruction* call) override { TF_RETURN_IF_ERROR(DefaultAction(call)); - CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call)); return Status::OK(); } @@ -232,17 +241,20 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override { TF_RETURN_IF_ERROR(DefaultAction(xla_while)); - CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR( xla_while->while_condition()->Accept(&candidates_for_condition)); - CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body)); return Status::OK(); } std::unordered_map* hlo_to_profile_idx_; + const std::unordered_map& assigned_indices_; }; } // namespace @@ -262,14 +274,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); - + pipeline.AddPass(); pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification"); pass.AddInvariantChecker(ShapeSizeBytesFunction()); - pass.AddPass( + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, @@ -277,7 +289,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, - /*enable_dot_simplification=*/false); + /*enable_dot_strength_reduction=*/false); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -306,8 +318,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, - /*enable_dot_simplification=*/false); + /*enable_dot_strength_reduction=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); + pipeline.AddPass(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -332,15 +345,16 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); if (options::CpuParallelBackendRequested(module->config())) { // Re-run the outlining, in case any copies were inserted into the entry // computation. pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); + pipeline.AddPass(); } pipeline.AddPass(); - pipeline.AddPass(); return pipeline.Run(module).status(); } @@ -426,11 +440,25 @@ Status InitializeModuleHooks( } // namespace -StatusOr> CpuCompiler::Compile( - std::unique_ptr module, se::StreamExecutor* stream_exec) { +StatusOr> CpuCompiler::RunHloPasses( + std::unique_ptr module, + perftools::gputools::StreamExecutor* /*stream_exec*/) { + VLOG(2) << "Before optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + + VLOG(2) << "After optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + return std::move(module); +} + +StatusOr> CpuCompiler::RunBackend( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) { const string timer_message = "Compiling [" + module->name() + "] for CPU using JIT"; - ScopedLoggingTimer compiling_timer(timer_message, 1); + XLA_SCOPED_LOGGING_TIMER(timer_message); VLOG(1) << "Compiling: " << module->name(); TF_RET_CHECK(stream_exec != nullptr); @@ -444,11 +472,11 @@ StatusOr> CpuCompiler::Compile( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = MakeUnique(); + auto llvm_context = xla::MakeUnique(); auto llvm_module = - MakeUnique("__compute_module", *llvm_context); + xla::MakeUnique("__compute_module", *llvm_context); - auto jit = MakeUnique( + auto jit = xla::MakeUnique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -458,20 +486,29 @@ StatusOr> CpuCompiler::Compile( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - VLOG(2) << "Before optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); - - VLOG(2) << "After optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; + std::unique_ptr hlo_profile_index_map; + std::unique_ptr hlo_profile_printer; if (module->config().hlo_profiling_enabled()) { + hlo_profile_index_map = MakeUnique(*module); + TF_ASSIGN_OR_RETURN( hlo_to_profile_idx, - CollectProfileCandidates::GetCandidatesForComputation(computation)); + CollectProfileCandidates::GetCandidatesForComputation( + computation, hlo_profile_index_map->instruction_to_profile_idx())); + + auto shape_size_bytes = [](const Shape& shape) { + // On the cpu, opaques are pointers. + if (ShapeUtil::IsOpaque(shape)) { + return static_cast(sizeof(void*)); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + }; + + HloCostAnalysis cost_analysis(shape_size_bytes); + hlo_profile_printer = + CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis); } std::unique_ptr cpu_executable; @@ -494,9 +531,9 @@ StatusOr> CpuCompiler::Compile( // uses data dependencies for determining order. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(module.get(), - MakeUnique(module.get()), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run( + module.get(), xla::MakeUnique(module.get()), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -523,7 +560,7 @@ StatusOr> CpuCompiler::Compile( const void* data = instruction->literal().InternalData(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( - instruction, MakeUnique(size)); + instruction, xla::MakeUnique(size)); CHECK_EQ(iter.second, true); unsigned char* aligned_data = iter.first->second.get(); memcpy(aligned_data, data, size); @@ -537,11 +574,17 @@ StatusOr> CpuCompiler::Compile( parallel_computations.emplace(to_apply, instruction); } - size_t entry_computation_profile_idx = hlo_to_profile_idx.size(); - IrEmitter ir_emitter( - *module, *assignment, llvm_module.get(), std::move(hlo_to_profile_idx), - /*entry_computation_profile_idx=*/entry_computation_profile_idx, - jit->target_machine(), jit->external_constant_pool()); + // We always profile the entire computation as a whole, even if hlo + // profiling is disabled. When hlo profiling is diabled, we pass in a + // profile counter array of just one element, which corresponds to the whole + // computation. + size_t entry_computation_profile_idx = + hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor( + *module->entry_computation()) + : 0; + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + hlo_to_profile_idx, entry_computation_profile_idx, + jit->target_machine(), jit->external_constant_pool()); std::unique_ptr> function_names( new HloInstructionMap()); @@ -560,7 +603,7 @@ StatusOr> CpuCompiler::Compile( llvm::Function * ir_function, ir_emitter.EmitComputation( embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/computation_is_parallel, + /*is_top_level_computation=*/computation_is_parallel, /*instruction_order=*/nullptr)); // If this computation is parallel, remember it in the function name map. // This way we know what function to execute when we try to run code for @@ -581,8 +624,8 @@ StatusOr> CpuCompiler::Compile( jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new ParallelCpuExecutable( std::move(jit), std::move(assignment), std::move(module), - std::move(function_names), std::move(hlo_to_profile_idx), - std::move(aligned_constants))); + std::move(function_names), std::move(aligned_constants), + std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -602,10 +645,10 @@ StatusOr> CpuCompiler::Compile( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run(module.get(), + xla::MakeUnique( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -615,15 +658,23 @@ StatusOr> CpuCompiler::Compile( TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( proto, xla_dump_hlo_proto_to, module->name())); } + // We always profile the entire computation as a whole, even if hlo + // profiling is disabled. When hlo profiling is diabled, we pass in a + // profile counter array of just one element, which corresponds to the whole + // computation. + size_t entry_computation_profile_idx = + hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor( + *module->entry_computation()) + : 0; + // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. - size_t entry_computation_profile_idx = hlo_to_profile_idx.size(); - IrEmitter ir_emitter( - *module, *assignment, llvm_module.get(), std::move(hlo_to_profile_idx), - /*entry_computation_profile_idx=*/entry_computation_profile_idx, - jit->target_machine(), jit->external_constant_pool()); + + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + hlo_to_profile_idx, entry_computation_profile_idx, + jit->target_machine(), jit->external_constant_pool()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -634,7 +685,7 @@ StatusOr> CpuCompiler::Compile( ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } @@ -643,7 +694,7 @@ StatusOr> CpuCompiler::Compile( TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, ir_emitter.EmitComputation(computation, function_name_prefix, - /*is_entry_computation=*/true, + /*is_top_level_computation=*/true, &module_sequence.at(computation))); string function_name = llvm_ir::AsString(entry_function->getName()); @@ -656,7 +707,7 @@ StatusOr> CpuCompiler::Compile( jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new CpuExecutable( std::move(jit), std::move(assignment), std::move(module), function_name, - std::move(hlo_to_profile_idx))); + std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -776,7 +827,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run( - module, MakeUnique(module, module_sequence), + module, + xla::MakeUnique(module, module_sequence), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. @@ -807,7 +859,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } @@ -815,7 +867,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, ir_emitter.EmitComputation(computation, entry_point_name, - /*is_entry_computation=*/true, + /*is_top_level_computation=*/true, &module_sequence.at(computation))); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 963aced208813e58b3d069a80bd88fcb05d8253f..ebed7058d8f7968c6e03ef90d0da6b2325037eb0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -116,7 +116,11 @@ class CpuCompiler : public LLVMCompiler { // stream_execs) using LLVMCompiler::Compile; - StatusOr> Compile( + StatusOr> RunHloPasses( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr> RunBackend( std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc new file mode 100644 index 0000000000000000000000000000000000000000..baaacd2ecc9611946678f71ac36ef787ecb57b4e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr CpuCopyInsertion::Run(HloModule* module) { + CopyInsertion generic_copy_insertion; + + TF_ASSIGN_OR_RETURN(bool generic_changed, generic_copy_insertion.Run(module)); + + // The CPU backend needs additional copies added due to deficiencies in + // buffer assignment. + TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed, + CopyInsertion::AddCopiesForBufferAssignment(module)); + + return generic_changed || buffer_assignment_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h new file mode 100644 index 0000000000000000000000000000000000000000..3313d1e6eb71bff39f509c3d24858568df786422 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Besides the modifications made by the generic xla::CopyInsertion, this +// CPU-specific copy insertion pass also adds copies to values live out of +// computations satisfying certain conditions (defined by constant or parameter, +// etc). This is necessary because of deficiencies of buffer +// assignment. Specifically, buffer assignment is computation-scoped and does +// not recognized aliasing between arguments and outputs of computations. +// +// TODO(b/62548313): Remove this when buffer assignment is smarter +// (module-scoped). +class CpuCopyInsertion : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "copy-insertion"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a05a26941786cbf404c4685abb098c9ac8caaa09 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +int64 CountCopies(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + count++; + } + } + return count; +} + +int64 CountCopies(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountCopies(*computation); + } + return count; +} + +class CpuCopyInsertionTest : public HloTestBase { + protected: + void InsertCopies(HloModule* module) { + CpuCopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module).status()); + } + + const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { + // Test a while body and condition which are each simply a constant (root of + // computation is a constant). Each constant should be copied. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + + auto body_builder = HloComputation::Builder("body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 3); + + EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); + EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); + EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant())); +} + +TEST_F(CpuCopyInsertionTest, TupleCall) { + // Test a kCall instruction which calls a computation which produces a three + // element tuple: one is a constant, one is a parameter, and one is produced + // in the computation. The constant and parameter should be copied. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_}); + + auto sub_builder = HloComputation::Builder("subcomputation"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto constant = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, sub_param, constant)); + sub_builder.AddInstruction( + HloInstruction::CreateTuple({sub_param, constant, add})); + HloComputation* subcomputation = + module->AddEmbeddedComputation(sub_builder.Build()); + + builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape, {param}, subcomputation)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*subcomputation), 2); + EXPECT_THAT(subcomputation->root_instruction(), + op::Tuple(op::Copy(op::Parameter()), op::Copy(op::Constant()), + op::Add())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index f62353bee7b1058dc237169b70341c33ab19fc52..028f827337979de14ec557a8f0d7a47f095bf55e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/host/host_stream.h" namespace se = ::perftools::gputools; @@ -54,11 +55,12 @@ CpuExecutable::CpuExecutable( std::unique_ptr assignment, std::unique_ptr hlo_module, const string& entry_function_name, - std::unordered_map hlo_to_profile_idx) - : Executable(std::move(hlo_module)), + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), jit_(std::move(jit)), - assignment_(std::move(assignment)), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) { + assignment_(std::move(assignment)) { // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name); @@ -71,28 +73,6 @@ CpuExecutable::CpuExecutable( reinterpret_cast(cantFail(sym.getAddress())); } -// Given a pointer to an output buffer (following the CPU JIT calling -// conventions), mark addresses that are "live". The initial pointer itself is -// trivially live. If the shape of the buffer is a tuple, this analysis looks -// into the tuple's elements and marks them live as well (since tuples keep -// pointers to buffers) and also works recursively. address is an in-memory -// buffer address that contains some runtime XLA object. shape is its -// shape. marked_addresses is the set of live addresses to populate. -static void MarkLiveAddressesInOutput( - const void* address, const Shape& shape, - std::unordered_set* marked_addresses) { - marked_addresses->insert(address); - const uintptr_t* address_buffer = static_cast(address); - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const uintptr_t* element_address = address_buffer + i; - const void* element = reinterpret_cast(*element_address); - MarkLiveAddressesInOutput( - element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses); - } - } -} - Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, std::vector* buffers) { @@ -146,19 +126,6 @@ Status CpuExecutable::ExecuteComputeFunction( tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { - std::vector argument_buffers; - for (int i = 0; i < arguments.size(); ++i) { - argument_buffers.push_back(arguments[i]->buffer(/*index=*/{})); - } - return ExecuteComputeFunction(run_options, argument_buffers, buffers, - hlo_execution_profile); -} - -Status CpuExecutable::ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: // // void function(void* result, const void* run_options, void** args_array, @@ -174,16 +141,23 @@ Status CpuExecutable::ExecuteComputeFunction( // determined by buffer analysis. // std::vector args_array; - for (se::DeviceMemoryBase arg_mem : arguments) { - args_array.push_back(arg_mem.opaque()); + for (const ShapedBuffer* argument : arguments) { + args_array.push_back(argument->root_buffer().opaque()); } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); // Allocate profiling counters for each hlo instruction that we would like to - // profile. Allocate an additional profile counter for the entire - // computation. - std::vector profile_counters(hlo_to_profile_idx_.size() + 1); + // profile. Even when not Hlo profiling, we allocate a counter for the entire + // computation, which we use to update ExecutionProfile below. + std::vector* profile_counters = nullptr; + std::vector profile_counter_for_entry_computation; + if (hlo_execution_profile) { + profile_counters = hlo_execution_profile->mutable_profile_counters(); + } else { + profile_counters = &profile_counter_for_entry_computation; + profile_counter_for_entry_computation.push_back(0); + } // Call the computation function following the calling convention. std::vector buffer_pointers; @@ -198,7 +172,7 @@ Status CpuExecutable::ExecuteComputeFunction( VLOG(3) << tensorflow::strings::Printf( " func(void* result, void* params[%zu], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters.size()); + args_array.size(), buffer_pointers.size(), profile_counters->size()); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); @@ -210,11 +184,11 @@ Status CpuExecutable::ExecuteComputeFunction( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters.data()); + profile_counters->data()); } compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters.data()); + buffer_pointers.data(), profile_counters->data()); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -223,67 +197,97 @@ Status CpuExecutable::ExecuteComputeFunction( const double nanoseconds = (end_micros - start_micros) * 1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - // The last profile counter is used for the computation as a whole. - execution_profile_.set_compute_cycle_count(profile_counters.back()); - } - - if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed( - *module().entry_computation(), profile_counters.back()); - - for (auto hlo_prof_idx : hlo_to_profile_idx_) { - const HloInstruction* hlo = hlo_prof_idx.first; - uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken); + if (hlo_execution_profile) { + execution_profile_.set_compute_cycle_count( + hlo_execution_profile->total_cycles_executed( + *module().entry_computation())); + } else { + execution_profile_.set_compute_cycle_count(profile_counters->back()); } } + return Status::OK(); } -StatusOr CpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. - std::unordered_set marked_addresses; - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; - MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), - &marked_addresses); +static void LogLiveAddresses( + tensorflow::gtl::ArraySlice buffers, + const std::vector& buffers_in_result) { + if (!VLOG_IS_ON(3)) { + return; + } + CHECK_EQ(buffers.size(), buffers_in_result.size()); + std::vector live_out_buffers; + for (int i = 0; i < buffers.size(); ++i) { + if (buffers_in_result[i]) { + live_out_buffers.push_back(buffers[i].opaque()); + } + } VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" + << live_out_buffers.size() << " addresses:\n" << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { + live_out_buffers, ", ", [](string* out, const void* address) { tensorflow::strings::StrAppend( out, tensorflow::strings::Printf("%p", address)); }); +} - // Computation is done - deallocate temp buffers. Keep those marked live - // because they are referenced by the output of the computation and are needed +static Status DeallocateTempBuffers( + DeviceMemoryAllocator* allocator, se::Stream* stream, + tensorflow::gtl::ArraySlice buffers, + const std::vector& buffers_in_result) { + // Keep those buffers in the output of the marked live because they are needed // by the service. They will be deallocated by the service. for (size_t i = 0; i < buffers.size(); ++i) { se::DeviceMemoryBase alloc = buffers[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { + if (!buffers_in_result[i] && !alloc.is_null()) { VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); + TF_RETURN_IF_ERROR( + allocator->Deallocate(stream->parent()->device_ordinal(), &alloc)); } } - return top_level_output; + return Status::OK(); +} + +StatusOr> CpuExecutable::CreateResultShapedBuffer( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + allocated_buffers, + std::vector* buffers_in_result) { + se::Stream* stream = run_options->stream(); + auto result_buffer = MakeUnique( + /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal()); + + // Copy DeviceMemoryBase values which contain the array(s) of the result into + // the respective location in ShapedBuffer which is returned to the caller. + TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer such as + // a tuple element. The source instruction should have a + // non-parameter buffer assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *device_memory = buffer; + (*buffers_in_result)[buffer_index] = true; + return Status::OK(); + })); + return std::move(result_buffer); } StatusOr> CpuExecutable::ExecuteOnStream( @@ -298,70 +302,60 @@ StatusOr> CpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - auto result_buffer = - MakeUnique(result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal()); - TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); TF_RETURN_IF_ERROR(ExecuteComputeFunction( &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); // Free all buffers not in the result. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } + TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, + buffers_in_result)); return std::move(result_buffer); } -StatusOr -CpuExecutable::ExecuteAsyncOnStream( +StatusOr> CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { - // TODO(b/30671675): Implement asynchronous execution mode. - return Unimplemented( - "Asynchronous execution on stream is not yet supported on CPU."); + tensorflow::gtl::ArraySlice arguments) { + if (hlo_profiling_enabled()) { + return Unimplemented( + "Asynchronous execution on stream with hlo profiling is not yet " + "supported on CPU."); + } + + auto* host_stream = dynamic_cast( + run_options->stream()->implementation()); + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector buffers(assignment_->Allocations().size()); + + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + + std::vector buffers_in_result(assignment_->Allocations().size(), false); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); + + LogLiveAddresses(buffers, buffers_in_result); + + host_stream->EnqueueTask([this, run_options, arguments, buffers, + buffers_in_result, memory_allocator, stream]() { + // Failing a CHECK here is not great, but I don't see an obvious way to + // return a failed Status asynchronously. + TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, + buffers, + /*hlo_execution_profile=*/nullptr)); + TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, + buffers_in_result)); + }); + + return std::move(result_buffer); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { @@ -377,9 +371,5 @@ const PointsToSet& CpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr CpuExecutable::CreateCostAnalysis() const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 238bc9b46ae2bf1b519eaf137d9ae063e769bd2e..50443a59954e222f65fc935e83effdaf6d6c8bf0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -47,29 +47,22 @@ namespace cpu { // architecture, so JIT-ed code and host code share the same ABI. class CpuExecutable : public Executable { public: - CpuExecutable( - std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - const string& entry_function_name, - std::unordered_map hlo_to_profile_idx); + CpuExecutable(std::unique_ptr jit, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + const string& entry_function_name, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -85,12 +78,10 @@ class CpuExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); - std::unique_ptr CreateCostAnalysis() const override; - // Type of the computation function we expect in the JIT. using ComputeFunctionType = void (*)( void* /*result*/, const ExecutableRunOptions* /*run_options*/, - const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/); + const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/); const ComputeFunctionType& compute_function() const { return compute_function_; @@ -110,13 +101,6 @@ class CpuExecutable : public Executable { // Calls the generated function performing the computation with the given // arguments using the supplied buffers. - Status ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -124,6 +108,18 @@ class CpuExecutable : public Executable { buffers, HloExecutionProfile* hlo_execution_profile); + // Create a ShapedBuffer for holding the result of the computation. The + // addresses (DeviceMemoryBases) are set according to buffer assignment. + // 'buffers_in_result' should point to a vector of the same size as + // 'allocated_buffers'. An element in buffers_in_result is set to true if the + // corresponding buffer is live out of the computation (and thus contained in + // the returned ShapedBuffer). + StatusOr> CreateResultShapedBuffer( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + allocated_buffers, + std::vector* buffers_in_result); + // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. const PointsToSet& GetRootPointsToSet() const; @@ -145,9 +141,6 @@ class CpuExecutable : public Executable { // Entry function name for the computation. const string entry_function_name_; - // Maps HLOs to their index into the profile counter array. - const std::unordered_map hlo_to_profile_idx_; - TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index f87ee3cecd932faac140636a3db7cd4aa0371b85..482e04052d5a914eab0e5bff2c7a83f3b698052f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -26,7 +26,7 @@ int64 BytesInDimension(const Shape& shape, int64 dimension) { shape.dimensions(dimension); } -bool IsFusile(const HloInstruction& hlo) { +bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. return hlo.IsElementwise() || // @@ -42,6 +42,23 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsMatrixVectorDot(const HloInstruction* hlo) { + const Shape& hlo_shape = hlo->shape(); + return hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 && + (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); +} + +bool CanBeOutputFused(const HloInstruction* producer, + const HloInstruction* consumer) { + return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) && + producer->user_count() == 1; +} + +bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) { + return consumer->opcode() == HloOpcode::kAdd && + (CanBeOutputFused(consumer->operand(0), consumer) || + CanBeOutputFused(consumer->operand(1), consumer)); +} } // namespace bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, @@ -52,7 +69,15 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, constexpr int kFusionThresholdBytes = 16 * 1024; - if (!IsFusile(*producer)) { + if (CanBeOutputFused(producer, consumer)) { + return true; + } + + if (CanBeOutputFusedIntoSomeOperand(producer)) { + return false; + } + + if (!CanBeLoopFused(*producer)) { VLOG(2) << "Producer is not fusile."; return false; } @@ -108,16 +133,13 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } } - if (consumer->opcode() == HloOpcode::kFusion) { - // InstructionFusion::ShouldFuse above only allows kLoop and kInput fusions. - // The CPU backend does not create kInput fusions, so we only expect to see - // kLoop here. - CHECK(consumer->fusion_kind() == HloInstruction::FusionKind::kLoop); + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop) { VLOG(2) << "Fusing: consumer is a fusion node."; return true; } - if (IsFusile(*consumer)) { + if (CanBeLoopFused(*consumer)) { VLOG(2) << "Fusing: consumer is elementwise or fusile."; return true; } @@ -126,5 +148,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } +HloInstruction::FusionKind CpuInstructionFusion::ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) { + return CanBeOutputFused(producer, consumer) + ? HloInstruction::FusionKind::kOutput + : HloInstruction::FusionKind::kLoop; +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h index 0eca4c3473e1454fe5dbd8bf855b4418cf553a94..07aff34974e0cfa6c7a129f82017b280fb1ccd59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h @@ -30,6 +30,8 @@ class CpuInstructionFusion : public InstructionFusion { protected: bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) override; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index b9e4d006d77ae76e33ac51440349400ea4eff118..595c3f55b321f47e2312b93e0c238c7637495d77 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -31,6 +31,14 @@ namespace { using InstructionFusionTest = HloTestBase; +std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); +} + TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -40,8 +48,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, exp0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -59,8 +67,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -80,8 +88,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, bitcast0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -102,8 +110,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* reshape0 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, reshape0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -121,8 +129,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 32 * 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -140,8 +148,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -162,8 +170,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, transpose1)); + builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -188,7 +196,9 @@ class OpcodeFusionTest : public InstructionFusionTest { // Runs CPU instruction fusion on the given module, and tests that the result // contains a fused op at the root with exactly the given multiset of opcodes. void RunFusionAndCheckOpcodesWereFused( - HloModule* module, const std::multiset& expected_opcodes) { + HloModule* module, const std::multiset& expected_opcodes, + HloInstruction::FusionKind fusion_kind = + HloInstruction::FusionKind::kLoop) { auto computation = module->entry_computation(); auto did_fusion = CpuInstructionFusion().Run(module); ASSERT_TRUE(did_fusion.ok()); @@ -196,7 +206,7 @@ class OpcodeFusionTest : public InstructionFusionTest { HloInstruction* root = computation->root_instruction(); ASSERT_THAT(root, op::Fusion()); - EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_EQ(root->fusion_kind(), fusion_kind); std::vector fused_opcodes(root->fused_instruction_count()); std::transform(root->fused_instructions().begin(), @@ -608,6 +618,88 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { Not(op::Fusion())); } +void CreateComputationForDotAddOutputFusionTest(const string& test_name, + HloModule* module, int m, int k, + int n, + bool add_extra_use_for_dot) { + HloComputation::Builder builder(test_name); + + Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + + auto* dot_lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); + auto* dot_rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, dot_rhs_shape, "param1")); + auto* addend = builder.AddInstruction( + HloInstruction::CreateParameter(2, dot_shape, "param2")); + + auto* dot = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + builder.AddInstruction( + HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); + + if (add_extra_use_for_dot) { + builder.AddInstruction( + HloInstruction::CreateOutfeed(dot_shape, dot, "no_config")); + } + + module->AddEntryComputation(builder.Build()); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, + /*k=*/50, /*n=*/19, + /*add_extra_use_for_dot=*/false); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}, + HloInstruction::FusionKind::kOutput); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/1, + /*add_extra_use_for_dot=*/false); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}, + HloInstruction::FusionKind::kOutput); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/19, + /*add_extra_use_for_dot=*/false); + + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/1, + /*add_extra_use_for_dot=*/true); + + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc similarity index 55% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 3f2d101959db50d9f775097f01d5a2ba25a0da8c..e8117377e61a4e21b8c45b929c518a18878fcb60 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include @@ -25,58 +25,77 @@ limitations under the License. namespace xla { namespace cpu { -Status CpuLayoutAssignment::AddBackendConstraints( - LayoutConstraints* constraints) { - auto row_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - auto col_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.begin(), dimension_order.end(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - - // We want to change the layout of constant arrays to be column major when all - // of their users are dot operations that can be made faster with the flipped - // layout. To avoid going quadriatic over the # of instructions, we cache - // this property in should_make_rhs_col_major -- it maps a constant to true if - // all of the users of said constant are dot operations that can be sped up. - // This cache is populated lazily as we encounter dot operations traversing - // the instruction stream. - tensorflow::gtl::FlatMap - should_make_rhs_col_major_cache; - auto should_make_rhs_col_major = [&](const HloInstruction& instruction) { - if (ProfitableToImplementDotInUntiledLlvmIr(instruction) != - DotInLlvmIrProfitable::kWithColumnMajorRhs) { +// We want to change the layout of constant arrays to be column major when all +// of their users are dot operations that can be made faster with the flipped +// layout. To avoid going quadriatic over the # of instructions, we cache this +// property in should_make_rhs_col_major -- it maps a constant to true if all of +// the users of said constant are dot operations that can be sped up. This +// cache is populated lazily as we encounter dot operations traversing the +// instruction stream. + +namespace { +using ::tensorflow::gtl::nullopt; +using ::tensorflow::gtl::optional; + +using ShouldMakeOperandColMajorCache = + tensorflow::gtl::FlatMap; +} // namespace + +static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { + for (auto* user : instruction->users()) { + optional operand_idx = ProfitableToMakeDotOperandColumnMajor(*user); + if (!operand_idx || user->operand(*operand_idx) != instruction || + std::count(user->operands().begin(), user->operands().end(), + instruction) != 1) { return false; } + } + return true; +} - const auto* rhs = instruction.operand(1); - if (rhs->opcode() != HloOpcode::kConstant) { - return false; - } +static optional ShouldMakeOperandColumnMajor( + ShouldMakeOperandColMajorCache* cache, const HloInstruction& instruction) { + optional operand_idx = + ProfitableToMakeDotOperandColumnMajor(instruction); + if (!operand_idx) { + return nullopt; + } - auto it = should_make_rhs_col_major_cache.find(rhs); - if (it != should_make_rhs_col_major_cache.end()) { - return it->second; - } + const HloInstruction* operand = instruction.operand(*operand_idx); + if (operand->opcode() != HloOpcode::kConstant) { + return nullopt; + } - bool result = std::all_of( - rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) { - return ProfitableToImplementDotInUntiledLlvmIr(*user) == - DotInLlvmIrProfitable::kWithColumnMajorRhs && - user->operand(0) != rhs; - }); + auto it = cache->find(operand); + if (it == cache->end()) { + auto insert_result = + cache->insert({operand, ShouldMakeAllUsersColMajor(operand)}); + CHECK(insert_result.second); + it = insert_result.first; + } - InsertOrDie(&should_make_rhs_col_major_cache, rhs, result); - return result; - }; + return it->second ? operand_idx : nullopt; +} + +static Shape RowMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +static Shape ColMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.begin(), dimension_order.end(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +Status CpuLayoutAssignment::AddBackendConstraints( + LayoutConstraints* constraints) { + ShouldMakeOperandColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { @@ -91,9 +110,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. - Shape output_shape(row_major_shape(convolution->shape())); - Shape input_shape(row_major_shape(lhs_instruction->shape())); - Shape filter_shape(row_major_shape(rhs_instruction->shape())); + Shape output_shape(RowMajorShape(convolution->shape())); + Shape input_shape(RowMajorShape(lhs_instruction->shape())); + Shape filter_shape(RowMajorShape(rhs_instruction->shape())); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR( @@ -102,11 +121,11 @@ Status CpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(filter_shape, convolution, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, convolution)); - } else if (should_make_rhs_col_major(*instruction)) { - auto* dot = instruction; - const auto& rhs_shape = dot->operand(1)->shape(); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(col_major_shape(rhs_shape), dot, 1)); + } else if (optional op_idx = + ShouldMakeOperandColumnMajor(&cache, *instruction)) { + const HloInstruction* op = instruction->operand(*op_idx); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + ColMajorShape(op->shape()), instruction, *op_idx)); } else if (PotentiallyImplementedAsEigenDot(*instruction)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, @@ -114,17 +133,17 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. - Shape output_shape(row_major_shape(dot->shape())); + Shape output_shape(RowMajorShape(dot->shape())); const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(row_major_shape(lhs_instruction->shape())); + Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); // dot is a kDot or a kTransposeDot fusion node. In the latter case, if // it represents X @ X, it may have just one operand. if (dot->operand_count() > 1) { const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(row_major_shape(rhs_instruction->shape())); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); } @@ -141,8 +160,12 @@ Status CpuLayoutAssignment::AddBackendConstraints( if (constraints->OperandBufferForwarded(instruction, operand_no)) { continue; } + // Skip operands with non-array shapes. + if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + continue; + } Shape operand_shape( - row_major_shape(instruction->operand(operand_no)->shape())); + RowMajorShape(instruction->operand(operand_no)->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( operand_shape, instruction, operand_no)); } diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h similarity index 86% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.h rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 4fd8d68dd6b4f2a8b16f6c048743a996ea76a560..c8edbb9e15a5b6f9c574f5fe9d130d149499ebd2 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -38,4 +38,4 @@ class CpuLayoutAssignment : public LayoutAssignment { } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc similarity index 54% rename from tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 1ea5e8c7fc4896512e62396d0a756cda44785f11..6ba030fff3bbc5f413bfb133114ceb5309b77672 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include #include @@ -40,6 +40,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -61,8 +63,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -98,10 +100,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { HloInstruction::CreateParameter(1, lhs_shape, "param1")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs)); - auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs)); + auto dot_a_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + auto dot_b_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -142,10 +144,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { HloInstruction::CreateParameter(1, lhs_b_shape, "param1")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_a_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs)); - auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_b_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs)); + auto dot_a_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + auto dot_b_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -180,8 +182,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape))); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); - auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -220,8 +222,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); - auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -241,5 +243,172 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { EXPECT_NE(instruction->opcode(), HloOpcode::kCopy); } } + +struct DotOutputFusionLayoutAssignmentResult { + bool layout_assignment_changed_something; + const HloInstruction* dot_lhs_fusion_param; + const HloInstruction* dot_rhs_fusion_param; + const HloInstruction* addend_fusion_param; +}; + +static StatusOr RunDotOutputFusion( + HloModule* module, const string& test_name, int m, int k, int n, + const int64 dot_operand_idx_in_add) { + DotOutputFusionLayoutAssignmentResult result; + + CHECK(dot_operand_idx_in_add == 0 || dot_operand_idx_in_add == 1); + + auto builder = HloComputation::Builder(test_name); + + Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + + HloInstruction* dot_lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); + HloInstruction* addend = builder.AddInstruction( + HloInstruction::CreateParameter(1, dot_shape, "param1")); + HloInstruction* dot_rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); + HloInstruction* dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* add_result; + if (dot_operand_idx_in_add == 0) { + add_result = builder.AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, dot_result, addend)); + } else { + add_result = builder.AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, addend, dot_result)); + } + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloInstruction* fusion_instruction = + module->entry_computation()->AddInstruction(HloInstruction::CreateFusion( + dot_shape, HloInstruction::FusionKind::kOutput, add_result)); + TF_RETURN_IF_ERROR( + computation->ReplaceInstruction(add_result, fusion_instruction)); + + HloInstruction* fused_add = + fusion_instruction->fused_instructions_computation()->root_instruction(); + HloInstruction* fused_dot = fusion_instruction->FuseInstruction(dot_result); + + TF_RETURN_IF_ERROR( + computation->RemoveInstructionAndUnusedOperands(dot_result)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_lhs_shape)); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape)); + *computation_layout.mutable_result_layout() = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape)); + + result.dot_lhs_fusion_param = + fusion_instruction->operand(fused_dot->operand(0)->parameter_number()); + result.dot_rhs_fusion_param = + fusion_instruction->operand(fused_dot->operand(1)->parameter_number()); + result.addend_fusion_param = fusion_instruction->operand( + fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); + + cpu::CpuLayoutAssignment layout_assignment(&computation_layout); + TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, + layout_assignment.Run(module)); + + return result; +} + +static void AssertCorrectLayoutForDotOutputFusion( + const HloComputation* computation, + const DotOutputFusionLayoutAssignmentResult& layout_assignment_result, + bool expect_col_major_dot_rhs) { + Layout expected_dot_rhs_layout = expect_col_major_dot_rhs + ? LayoutUtil::MakeLayout({0, 1}) + : LayoutUtil::MakeLayout({1, 0}); + EXPECT_TRUE(LayoutUtil::Equal( + expected_dot_rhs_layout, + layout_assignment_result.dot_rhs_fusion_param->shape().layout())); + + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::MakeLayout({1, 0}), + layout_assignment_result.dot_lhs_fusion_param->shape().layout())); + + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::MakeLayout({1, 0}), + layout_assignment_result.addend_fusion_param->shape().layout())); + EXPECT_THAT(computation->instructions(), Each(Not(op::Copy()))); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/true); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/true); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h index acfada8540d89bb098bb0b04e109441e2123e678..74ae6d00c91be07c0d181ea324e570c73c6b2e77 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h @@ -38,14 +38,16 @@ typedef float V8F32AVX __attribute__((__vector_size__(32))); extern "C" { +#ifdef __AVX__ // The following functions are vectorized versions of a selection of libm // library functions. // References to these functions are created by the LLVM vectorizer. xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX( - xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V8F32AVX x); xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( - xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V8F32AVX x); +#endif } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h index 75cb16b273973d2bf665d378084343fd612a2941..645a43858fb8c3d8e7e94709333c88503b6cc52d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h @@ -49,14 +49,16 @@ struct V4F32NEON; extern "C" { +#ifdef __ARM_NEON__ // The following functions are vectorized versions of a selection of libm // library functions. // References to these functions are created by the LLVM vectorizer. xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( - xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V4F32NEON x); xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( - xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V4F32NEON x); +#endif // __ARM_NEON__ } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h index 96587d10d2b86e14ff6a7400fdf14ca0d994ddc5..1bd8494bf8494d2100e68841f974c86e2beb3859 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h @@ -39,14 +39,17 @@ typedef float V4F32SSE __attribute__((__vector_size__(16))); extern "C" { +#ifdef __SSE4_1__ // The following functions are vectorized versions of a selection of libm // library functions. // References to these functions are created by the LLVM vectorizer. xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( - xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V4F32SSE x); xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( - xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V4F32SSE x); +#endif + } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.h b/tensorflow/compiler/xla/service/cpu/disassembler.h index b6feaa7e45cee26eb7f850081bd1fad2cb63b15c..5e302f88990ee4a3c37758881ecec4d6f71dd8e6 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.h +++ b/tensorflow/compiler/xla/service/cpu/disassembler.h @@ -37,7 +37,7 @@ struct DisassemblerResult { DisassemblerResult(const string& text, size_t code_size_bytes) : text(text), code_size_bytes(code_size_bytes) {} - // The dissassembled text sections of the object file. + // The disassembled text sections of the object file. string text; // The total number of bytes of executable code in the object file. uint64_t code_size_bytes; @@ -53,7 +53,7 @@ class Disassembler { // Returns a DisassemblerResult for the given object file, containing the // disassembled code. // - // If we couldnt' retrieve a disassembler for this platform, an error status + // If we couldn't retrieve a disassembler for this platform, an error status // is returned. StatusOr DisassembleObjectFile( const llvm::object::ObjectFile& object_file) const; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 4c40dae5122b0853a72d6428fc120220e3a69237..74f71e5ad575134d78f834a9e63723c22ae49111 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -143,7 +143,8 @@ class ColumnMajorMatrixVectorProductEmitter { ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* ir_builder) : scalar_type_(scalar_type), tile_rows_(tile_rows), @@ -152,6 +153,7 @@ class ColumnMajorMatrixVectorProductEmitter { k_(k), lhs_(lhs), rhs_(rhs), + addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), @@ -198,6 +200,7 @@ class ColumnMajorMatrixVectorProductEmitter { int64 k_; llvm::Value* lhs_; llvm::Value* rhs_; + llvm::Value* addend_; llvm::Value* result_; llvm::IRBuilder<>* ir_builder_; KernelSupportLibrary ksl_; @@ -242,9 +245,10 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( /*step=*/tile_rows_, [&](llvm::Value* row) { std::vector lhs_tile = lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = is_first_column - ? vsl_.GetZeroVector() - : vsl_.LoadVector(result_, row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); for (int i = 0; i < columns; i++) { accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); } @@ -288,7 +292,18 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ir_builder_->getInt1(is_first_tiled_column)); ksl_.If( setting_result_first_time, - [&]() { vsl_.StoreScalar(product, result_, scalar_row); }, + /*true_block_generator=*/ + [&]() { + if (addend_) { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), + product), + result_, scalar_row); + } else { + vsl_.StoreScalar(product, result_, scalar_row); + } + }, + /*false_block_generator=*/ [&]() { vsl_.StoreScalar( vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), @@ -353,7 +368,7 @@ class RowMajorMatrixVectorProductEmitter { RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* result, + llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* ir_builder) : scalar_type_(scalar_type), tile_rows_(tile_rows), @@ -362,6 +377,7 @@ class RowMajorMatrixVectorProductEmitter { k_(k), lhs_(lhs), rhs_(rhs), + addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), @@ -394,6 +410,7 @@ class RowMajorMatrixVectorProductEmitter { int64 k_; llvm::Value* lhs_; llvm::Value* rhs_; + llvm::Value* addend_; llvm::Value* result_; llvm::IRBuilder<>* ir_builder_; KernelSupportLibrary ksl_; @@ -415,11 +432,32 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, &scalar_accumulators); + std::vector accumulator_values; + std::transform( + vector_accumulators.begin(), vector_accumulators.end(), + std::back_inserter(accumulator_values), + [](const VectorVariable& vector_var) { return vector_var.Get(); }); + + std::vector horizontal_sums; + if (row_count == vsl_.vector_size()) { + if (addend_) { + horizontal_sums = vsl_.ComputeHorizontalSums( + std::move(accumulator_values), vsl_.LoadVector(addend_, row)); + } else { + horizontal_sums = + vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + } else { + horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + for (int i = 0; i < row_count; i++) { llvm::Value* result_value = - vsl_.Add(vsl_.AddReduce(vector_accumulators[i].Get()), - scalar_accumulators[i].Get()); + vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); + if (addend_ && row_count != vsl_.vector_size()) { + result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); + } vsl_.StoreScalar(result_value, result_, offset); } } @@ -483,20 +521,19 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } // namespace -DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config) +DotOpEmitter::DotOpEmitter( + const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, + const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config) : dot_(dot), transpose_lhs_(transpose_lhs), transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), + addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), ir_builder_(ir_builder), hlo_module_config_(hlo_module_config) {} @@ -504,28 +541,29 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, /* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, rhs_array, executable_run_options_value, - ir_builder, hlo_module_config); + lhs_array, rhs_array, addend_array, + executable_run_options_value, ir_builder, + hlo_module_config); return dot_emitter.Emit(); } bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2 || - ProfitableToImplementDotInUntiledLlvmIr(dot_) == - DotInLlvmIrProfitable::kYes) { + if (dot_.shape().dimensions_size() != 2) { return false; } - if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) && - !primitive_util::IsIntegralType(dot_.shape().element_type())) { + PrimitiveType primitive_type = dot_.shape().element_type(); + + if (!primitive_util::IsFloatingPointType(primitive_type) && + !primitive_util::IsIntegralType(primitive_type)) { return false; } @@ -575,30 +613,63 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { int64 tiling_factor = GetGemvTilingFactor(); CHECK_GT(tiling_factor, 0); + llvm::Value* result_op = target_array_.GetBasePointer(); + llvm::Value* lhs_op = + swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer(); + llvm::Value* rhs_op = + swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/8, - /*tile_cols=*/tiling_factor, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); + int64 tile_rows = 8; + int64 tile_cols = tiling_factor; + + string kernel_name = tensorflow::strings::StrCat( + "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, + addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, + llvm::Value* result_op) { + ColumnMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + addend_op, result_op, ir_builder_); + emitter.Emit(); + }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - RowMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/tiling_factor, - /*tile_cols=*/8, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); + int64 tile_rows = tiling_factor; + int64 tile_cols = 8; + + string kernel_name = tensorflow::strings::StrCat( + "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, + addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, + llvm::Value* result_op) { + RowMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + addend_op, result_op, ir_builder_); + emitter.Emit(); + }); } return true; @@ -641,6 +712,8 @@ tensorflow::Status DotOpEmitter::Emit() { return Status::OK(); } + CHECK_EQ(addend_array_, nullptr); + if (PotentiallyImplementedAsEigenDot(dot_)) { return EmitCallToRuntime(); } @@ -915,8 +988,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - lhs_shape.layout().minor_to_major(0) == 0, - rhs_shape.layout().minor_to_major(0) == 0}; + LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + LayoutUtil::Minor(rhs_shape.layout(), 0) == 0}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -927,8 +1000,8 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( // reduction dimension. std::vector dimensions; const Shape& shape = operand_array.GetShape(); - for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape.layout().minor_to_major(i); + for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) { + int64 dimension = LayoutUtil::Minor(shape.layout(), i); if (dimension != reduction_dimension) { dimensions.push_back(dimension); } @@ -977,9 +1050,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { return false; } - if (ProfitableToImplementDotInUntiledLlvmIr(hlo) == - DotInLlvmIrProfitable::kYes || - ProfitableToImplementDotInTiledLlvmIr(hlo)) { + if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { return false; } @@ -1010,46 +1081,42 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { return false; } -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot) { - if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { - const Shape& result_shape = dot.shape(); - // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 - // cache line size, so that we can have the reduction dimension of both the - // LHS and RHS matrices and still have some space "left over". This needs - // to be tuned further. - const int64 kReductionDimensionThresholdBytes = 8 * 1024; - const bool single_threaded_eigen = - !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); - - // This is the point at which it is better to call into Eigen and shard the - // dot across multiple worker threads. This is a rough estimate by running - // a matmult benchmark on my local machine, and it can be tuned further. - const int64 kMaxSingleThreadedFlops = 16 * 1024; - - const int64 M = result_shape.dimensions(0); - const int64 N = result_shape.dimensions(1); - const int64 K = dot.operand(1)->shape().dimensions(0); - const int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); - if (M == 1 && - K * primitive_type_size <= kReductionDimensionThresholdBytes && - (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { - // Heuristics: - // - // - Look for a configuration where we will likely be able to keep LHS in - // L1 and do a cache-optimal traversal of RHS. - // - // - Bail out on matrices that are large enough that Eigen can profitably - // shard the computation across multiple cores. This only applies when - // multi-threading is enabled. - return LayoutUtil::IsMonotonicWithDim0Major( - dot.operand(1)->shape().layout()) - ? DotInLlvmIrProfitable::kWithColumnMajorRhs - : DotInLlvmIrProfitable::kYes; +// For vector-matrix dot products, it is always profitable to make the Rhs +// column major. +tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( + const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 && + hlo.shape().dimensions(0) == 1) { + if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) { + return 1; + } + return {}; + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) { + auto* fusion_root = + hlo.fused_instructions_computation()->root_instruction(); + if (fusion_root->opcode() != HloOpcode::kAdd) { + return {}; + } + + for (auto* fusion_root_op : fusion_root->operands()) { + if (fusion_root_op->opcode() != HloOpcode::kDot) { + continue; + } + if (auto operand_num = + ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) { + auto* operand = fusion_root_op->operand(*operand_num); + if (operand->opcode() == HloOpcode::kParameter && + operand->user_count() == 1) { + return operand->parameter_number(); + } + } } } - return DotInLlvmIrProfitable::kNo; + + return {}; } bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index c9168ccc0f6629c2a2bfbc7d4dc9c7ebab0a5708..2118965a70872846204974e25555340baca718cf 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -32,19 +32,11 @@ namespace cpu { bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); -enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; - -// Returns a value to indicate if (and under what conditions) will lowering -// |dot| as a untiled LLVM IR dot operation be profitable over calling into -// Eigen or emitting a tiled LLVM IR implementation. Possible return values -// are: -// -// * DotInLlvmIrProfitable::kYes - always profitable. -// * DotInLlvmIrProfitable::kNo - never profitable. -// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make -// the Rhs layout column major. -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot); +// Returns the index for an operand to `hlo` that should ideally be column +// major. Returns nullopt if there is no such operand or if `hlo` is not a dot +// or a fusion containing a dot. +tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( + const HloInstruction& hlo); // Returns true to indicate that we can generate a tiled LLVM IR implementation // for |dot|. @@ -57,10 +49,15 @@ class DotOpEmitter { // place the result in target_array. IR is emitted at current insert point of // the builder. Upon completion of the method, the insert point is set to the // end of all instructions emitted for this operation. + // + // If `addend_array` is not nullptr then it must be an array of the same + // dimensions as the result, and the result is computed as `addend_array` + + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported + // for Matrix-vector products. static tensorflow::Status EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config); @@ -69,6 +66,7 @@ class DotOpEmitter { bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config); @@ -140,6 +138,7 @@ class DotOpEmitter { const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; + const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* ir_builder_; const HloModuleConfig& hlo_module_config_; diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index ba693ec89ab7c4090f8c9d1e4d65f17a80d0ac55..ebd96c4c42759b71b79408c73814605301af03c1 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -44,15 +44,11 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( default: return Unimplemented("tanh"); } - // Create function type for the function. - llvm::FunctionType* function_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - /*isVarArg=*/false); // Create function declaration for 'tanhf'. llvm::Function* function = llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), function_type)); + llvm_ir::AsStringRef(function_name), operand_value->getType(), + operand_value->getType())); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); @@ -64,6 +60,31 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr CpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { + string function_name; + switch (prim_type) { + case F32: + function_name = "atan2f"; + break; + case F64: + function_name = "atan2"; + break; + default: + return Unimplemented("atan2"); + } + // Create function declaration for 'atan2'. + llvm::Function* function = + llvm::cast(module_->getOrInsertFunction( + llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(), + rhs->getType())); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create instruction to call 'atan2'. + return ir_builder_->CreateCall(function, {lhs, rhs}); +} + llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 7e9f27befb456c17581f556868712f92fd8fd083..4446dfd2821fb4b6e75f33694367392ecbcdd8bf 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -41,6 +41,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { protected: StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; + StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index cb5cb8a6dd6d01febde46ac7dc0950f947fd3265..788217aab6172b4e548452b3f6ffd4197c163ce4 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -29,10 +29,8 @@ bool PotentiallyImplementedAsEigenConvolution( // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. - // - the input is in NHWC or NWHC order. - // - the kernel is in HWIO or WHIO order. - // - the spatial dimensions are in the same relative order in the input, - // kernel and output. + // - the input is in NHWC order. + // - the kernel is in HWIO order. // // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); @@ -46,20 +44,30 @@ bool PotentiallyImplementedAsEigenConvolution( ShapeUtil::ElementIsComplex(kernel_shape)) { return false; } + if (window_util::HasWindowReversal(convolution.window())) { + return false; + } const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); // Only 1D and 2D convolutions are supported at the moment. // TODO(b/32897908): add an optimized implementation for 3D convolution. - if (dnums.spatial_dimensions_size() > 2) { + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + if (num_spatial_dims > 2) { return false; } - bool input_spatial_dims_ascending = std::is_sorted( - dnums.spatial_dimensions().begin(), dnums.spatial_dimensions().end()); - bool kernel_spatial_dims_ascending = - std::is_sorted(dnums.kernel_spatial_dimensions().begin(), - dnums.kernel_spatial_dimensions().end()); + for (int64 i = 0; i < num_spatial_dims; ++i) { + if (dnums.input_spatial_dimensions(i) != i + 1) { + return false; + } + if (dnums.kernel_spatial_dimensions(i) != i) { + return false; + } + if (dnums.output_spatial_dimensions(i) != i + 1) { + return false; + } + } const Shape& output_shape = convolution.shape(); return dnums.input_batch_dimension() == 0 && @@ -67,7 +75,6 @@ bool PotentiallyImplementedAsEigenConvolution( dnums.output_batch_dimension() == 0 && dnums.output_feature_dimension() == output_shape.dimensions_size() - 1 && - input_spatial_dims_ascending == kernel_spatial_dims_ascending && dnums.kernel_input_feature_dimension() == kernel_shape.dimensions_size() - 2 && dnums.kernel_output_feature_dimension() == diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index ac361ddfb4c8d253ffb1c99200939f6324cad2bb..34b2003916933f5ec0a15d9e219063c0a912fa40 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -23,6 +24,19 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution); + +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// See IrFunction and ParallelLoopEmitter for details. +using DynamicLoopBounds = std::vector>; + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index c00f1d5c1dbe8a7dcb92e98df6604081d5e496ae..ef33260c17168b1516264a2f69cb80afb04ddeef 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,16 +24,17 @@ limitations under the License. #include #include +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/Target/TargetRegisterInfo.h" -#include "llvm/Target/TargetSubtargetInfo.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -42,6 +43,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -124,131 +127,27 @@ StatusOr IrEmitter::EmitComputation( } else { TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); } - InsertOrDie(&emitted_functions_, computation, compute_function_); - - return compute_function_; -} - -static llvm::Argument* GetArg(llvm::Function* f, int idx) { - llvm::Function::arg_iterator arg_iter = f->arg_begin(); - std::advance(arg_iter, idx); - return &*arg_iter; + llvm::Function* ir_function = compute_function_->function(); + InsertOrDie(&emitted_functions_, computation, ir_function); + // Delete 'compute_function', finalizing 'ir_function' and restoring caller + // IR insert point. + compute_function_.reset(); + return ir_function; } void IrEmitter::InitializeIrFunction(const string& function_name) { - // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* dynamic_loop_bounds, i64* prof_counters) - // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. - // - // Therefore, the generated function's signature (FunctionType) is statically - // determined - parameter unpacking is done in code generated into the - // function, rather than by a prologue dictated by the platform ABI. - // - // /--------------\ - // retval ----------> | return value | - // \--------------/ - // - // /-------------------------------\ - // run_options -----> | xla::ExecutableRunOptions | - // \-------------------------------/ - // - // /---------------------------------------------\ - // params --------> | param 0 | param 1 | ..... | param N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | param 0 | | param 1 | | param N-1 | - // \---------/ \---------/ \-----------/ - // - // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | temp 0 | | temp 1 | | temp N-1 | - // \---------/ \---------/ \-----------/ - // - // /--------------------------------------------\ - // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| - // (elided for aot) \--------------------------------------------/ - // - // /---------------------------------------------\ - // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | - // (elided for aot) \---------------------------------------------/ - - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. - llvm::FunctionType* compute_function_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/GetComputeFunctionParams(), - /*isVarArg=*/false); - // Functions with local linkage get an inlining bonus. Because we know // a-priori that embedded functions (non-entry functions) will not have its // name resolved, give it local linkage. llvm::Function::LinkageTypes linkage = is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; - compute_function_ = - llvm::Function::Create(/*Ty=*/compute_function_type, - /*Linkage=*/linkage, - /*Name=*/AsStringRef(function_name), - /*Module=*/module_); - compute_function_->setCallingConv(llvm::CallingConv::C); - - // Set meaningful names for the function's arguments: useful for debugging. - llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin(); - arg_iter->setName("retval"); - (++arg_iter)->setName("run_options"); - (++arg_iter)->setName("params"); - (++arg_iter)->setName("temps"); - if (num_dynamic_loop_bounds_ > 0) { - (++arg_iter)->setName("dynamic_loop_bounds"); - } - (++arg_iter)->setName("prof_counters"); - - // We know a-priori that the function arguments are guaranteed to point to - // disjoint objects. - llvm::Argument* retval = GetResultArgument(); - for (llvm::Argument& argument : compute_function_->args()) { - // However, the return buffer aliases the temporaries and thus cannot be - // marked noalias. - if (&argument == retval) { - continue; - } - compute_function_->addAttribute(argument.getArgNo() + 1, - llvm::Attribute::NoAlias); - } - - // Add the optize attribute to the function if optimizing for size. This - // controls internal behavior of some optimization passes (e.g. loop - // unrolling). - if (options::OptimizeForSizeRequested(hlo_module_config_)) { - compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize); - } - - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - compute_function_->addFnAttr("unsafe-fp-math", "true"); - compute_function_->addFnAttr("no-infs-fp-math", "true"); - compute_function_->addFnAttr("no-nans-fp-math", "true"); - compute_function_->addFnAttr("no-signed-zeros-fp-math", "true"); - } - - ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( - /*Context=*/module_->getContext(), - /*Name=*/"entry", - /*Parent=*/compute_function_)); + // Create and initialize new IrFunction. + compute_function_.reset( + new IrFunction(function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_enable_fast_math(), + module_, &ir_builder_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -344,11 +243,12 @@ int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { // Calculate the alignment of a buffer allocated for a given primitive type. int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { - int64 buffer_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - DCHECK_GE(buffer_size, 0); - DCHECK_LE(buffer_size, SIZE_MAX); - - return MinimumAlignmentForBufferSize(buffer_size); + int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + DCHECK_GE(byte_size, 0); + // Largest scalar is a complex64 so we don't need to worry about the + // int64->int truncation here. + DCHECK_LE(byte_size, 8); + return byte_size; } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -357,6 +257,10 @@ int64 IrEmitter::ByteSizeOf(const Shape& shape) const { // Calculate the alignment of a buffer allocated for a given shape. int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { + if (ShapeUtil::IsScalar(shape)) { + return MinimumAlignmentForPrimitiveType(shape.element_type()); + } + int64 buffer_size = ByteSizeOf(shape); DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); @@ -612,7 +516,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32})); + /*supported_types=*/{F32, BF16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -795,7 +699,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. llvm_ir::IrArray::Index operand_index(source_index.size()); - llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); + llvm::Value* in_bounds_condition = ir_builder_.getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); @@ -822,14 +726,16 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); - const auto save_operand_index = [&]( - const llvm_ir::IrArray::Index& operand_index) { - for (int64 i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( - selected_index_address, {ir_builder_.getInt32(i)}); - ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); - } - }; + const auto save_operand_index = + [&](const llvm_ir::IrArray::Index& operand_index) { + for (int64 i = 0; i < rank; ++i) { + llvm::Value* selected_index_address_slot = + ir_builder_.CreateInBoundsGEP(selected_index_address, + {ir_builder_.getInt32(i)}); + ir_builder_.CreateStore(operand_index[i], + selected_index_address_slot); + } + }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); @@ -896,6 +802,24 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32, F64, C64})); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions_size() != 1) { + // This is disallowed by ShapeInference today. + return Unimplemented( + "Dot with multiple contracting dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions(0) != + std::min(lhs->shape().dimensions_size() - 1, 1) || + dnums.rhs_contracting_dimensions(0) != 0) { + return Unimplemented( + "Dot with non-standard contracting dimensions not implemented."); + } llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -914,8 +838,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_); + lhs_array, rhs_array, /*addend_array=*/nullptr, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -952,11 +876,12 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Input tensor. const Shape& input_shape = convolution->operand(0)->shape(); int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension()); - int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0)); + int64 input_rows = + input_shape.dimensions(dnums.input_spatial_dimensions(0)); int64 input_cols = one_dim_convolution ? 1 - : input_shape.dimensions(dnums.spatial_dimensions(1)); + : input_shape.dimensions(dnums.input_spatial_dimensions(1)); int64 input_channels = input_shape.dimensions(dnums.input_feature_dimension()); @@ -976,11 +901,11 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Output tensor. const Shape& convolution_shape = convolution->shape(); int64 output_rows = - convolution_shape.dimensions(dnums.spatial_dimensions(0)); - int64 output_cols = - one_dim_convolution - ? 1 - : convolution_shape.dimensions(dnums.spatial_dimensions(1)); + convolution_shape.dimensions(dnums.output_spatial_dimensions(0)); + int64 output_cols = one_dim_convolution + ? 1 + : convolution_shape.dimensions( + dnums.output_spatial_dimensions(1)); // Extract the window stride for the convolution. const Window& window = convolution->window(); @@ -1068,10 +993,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { return EmitTargetElementLoop( convolution, [this, convolution, lhs, rhs, window, dnums](const llvm_ir::IrArray::Index& index) { - int num_spatial_dims = dnums.spatial_dimensions_size(); + int num_spatial_dims = dnums.output_spatial_dimensions_size(); std::vector output_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { - output_spatial[i] = index[dnums.spatial_dimensions(i)]; + output_spatial[i] = index[dnums.output_spatial_dimensions(i)]; } llvm::Value* output_feature = index[dnums.output_feature_dimension()]; llvm::Value* batch = index[dnums.output_batch_dimension()]; @@ -1091,8 +1016,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { for (int i = 0; i < num_spatial_dims; ++i) { kernel_spatial[i] = loops - .AddLoop(0, rhs->shape().dimensions( - dnums.kernel_spatial_dimensions(i)), + .AddLoop(0, + rhs->shape().dimensions( + dnums.kernel_spatial_dimensions(i)), tensorflow::strings::StrCat("k", i)) ->GetIndVarValue(); } @@ -1108,17 +1034,18 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Calculate the spatial index in the input array, taking striding, // dilation and padding into account. An index in the padding will be // out of the bounds of the array. - const auto calculate_input_index = [this]( - llvm::Value* output_index, llvm::Value* kernel_index, - const WindowDimension& window_dim) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - output_index, ir_builder_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( - kernel_index, ir_builder_.getInt64(window_dim.window_dilation())); - return ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), - ir_builder_.getInt64(window_dim.padding_low())); - }; + const auto calculate_input_index = + [this](llvm::Value* output_index, llvm::Value* kernel_index, + const WindowDimension& window_dim) { + llvm::Value* strided_index = ir_builder_.CreateNSWMul( + output_index, ir_builder_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( + kernel_index, + ir_builder_.getInt64(window_dim.window_dilation())); + return ir_builder_.CreateNSWSub( + ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), + ir_builder_.getInt64(window_dim.padding_low())); + }; std::vector input_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { input_spatial[i] = calculate_input_index( @@ -1140,11 +1067,11 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0)); }; - llvm::Value* in_bounds_condition = nullptr; + llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); for (int i = 0; i < num_spatial_dims; ++i) { llvm::ConstantInt* input_bound = ir_builder_.getInt64(window_util::DilatedBound( - lhs->shape().dimensions(dnums.spatial_dimensions(i)), + lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), window.dimensions(i).base_dilation())); llvm::Value* dim_in_bound = ir_builder_.CreateICmpULT(input_spatial[i], input_bound); @@ -1153,9 +1080,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { llvm::Value* dim_ok = ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole); in_bounds_condition = - in_bounds_condition - ? ir_builder_.CreateAnd(in_bounds_condition, dim_ok) - : dim_ok; + ir_builder_.CreateAnd(in_bounds_condition, dim_ok); } // Now we need to map the dilated base coordinates back to the actual @@ -1178,7 +1103,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { int num_dims = num_spatial_dims + 2; llvm_ir::IrArray::Index input_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - input_index[dnums.spatial_dimensions(i)] = input_spatial[i]; + input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; } input_index[dnums.input_feature_dimension()] = input_feature; input_index[dnums.input_batch_dimension()] = batch; @@ -1186,8 +1111,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); llvm_ir::IrArray::Index kernel_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - kernel_index[dnums.kernel_spatial_dimensions(i)] = kernel_spatial[i]; + kernel_index[dnums.kernel_spatial_dimensions(i)] = + window.dimensions(i).window_reversal() + ? ir_builder_.CreateNSWSub( + ir_builder_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) + : kernel_spatial[i]; } + kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; @@ -1449,15 +1380,20 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { // // Where Param is the actual element type of the underlying buffer (for // example, float for an XLA F32 element type). - llvm::Argument* params = GetArg(compute_function_, 2); + llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = ir_builder_.CreateLoad(param_address_offset); param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); - if (hlo_module_config_.debug_options() + if (is_top_level_computation_ && + hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // We never reassign parameters, so this load is invariant. + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. param_address_untyped->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); @@ -1584,13 +1520,9 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( PrimitiveType element_type, unsigned element_count) { - // Here we assume that the largest register is a vector register. - int max_vector_register_size_in_bytes = - target_machine_features_.largest_register_size_in_bytes( - compute_function_); - int vector_register_size_in_elements = - max_vector_register_size_in_bytes / + target_machine_features_.vector_register_byte_size( + *compute_function_->function()) / ShapeUtil::ByteSizeOfPrimitiveType(element_type); ShardedVectorType sharded_vector_type; @@ -1745,19 +1677,6 @@ void IrEmitter::EmitShardedVectorStore( } } -namespace { -// TODO(sanjoy): This is duplicated in tensorflow/core/lib/core/arena.cc. -// Extract out a common implementation to tensorflow/core/lib/math/math_util.h -uint32 GCD(uint32 x, uint32 y) { - while (y != 0) { - uint32 r = x % y; - x = y; - y = r; - } - return x; -} -} // namespace - StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function, @@ -1778,11 +1697,12 @@ StatusOr IrEmitter::EmitVectorizedReduce( bool is_reduction_over_minor_dimension = std::find(dimensions.begin(), dimensions.end(), - arg->shape().layout().minor_to_major(0)) != dimensions.end(); + LayoutUtil::Minor(arg->shape().layout(), 0)) != + dimensions.end(); - unsigned element_alignment = - GCD(ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), - MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); + unsigned element_alignment = tensorflow::MathUtil::GCD( + ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), + MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); if (is_reduction_over_minor_dimension) { // TODO(sanjoy): Implement vectorized reduction over the minor dimension. @@ -1815,8 +1735,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); - for (int i = reduce->shape().layout().minor_to_major_size() - 1; i > 0; --i) { - int64 dimension = reduce->shape().layout().minor_to_major(i); + for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; + --i) { + int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); int64 start_index = 0; int64 end_index = reduce->shape().dimensions(dimension); std::unique_ptr loop = @@ -1825,7 +1746,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( array_index[dimension] = loop->GetIndVarValue(); } - int64 innermost_dimension = reduce->shape().layout().minor_to_major(0); + int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0); int64 innermost_dimension_size = reduce->shape().dimensions(innermost_dimension); @@ -1861,10 +1782,10 @@ StatusOr IrEmitter::EmitVectorizedReduce( target_array); if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) { - CHECK_GT(reduce->shape().layout().minor_to_major_size(), 1); + CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); ir_builder_.SetInsertPoint(exit_terminator); } else { - CHECK_EQ(reduce->shape().layout().minor_to_major_size(), 1); + CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); ir_builder_.SetInsertPoint(loop->GetExitBasicBlock()); } } @@ -1992,7 +1913,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); // The code below emits a sequential loop nest. For the parallel backend, use - // EmitParallelTargetElementLoop() which respects dynamic loop bounds. + // ParallelLoopEmitter which respects dynamic loop bounds. if (ShouldEmitParallelLoopFor(*slice)) { return DefaultAction(slice); } @@ -2024,7 +1945,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // * Implement the memcpy within the innermost loop. tensorflow::gtl::FlatSet inner_dims; - for (int64 dim : layout.minor_to_major()) { + for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; } @@ -2051,7 +1972,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // memcpy_dim is the innermost (in terms of layout) dimension for which the // slice does *not* just copy all the elements along the dimension. - const int64 memcpy_dim = layout.minor_to_major(inner_dims.size()); + const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size()); const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1; // The number of logical elements that can be copied in a single call @@ -2260,8 +2181,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *root, root->operand(0)->IsRank2Transpose(), root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_)); + rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), + &ir_builder_, hlo_module_config_)); return Status::OK(); } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { @@ -2282,6 +2203,35 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); + } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) { + VLOG(3) << "HandleFusion kOutput"; + int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1; + const HloInstruction* dot = root->operand(dot_op_index); + CHECK_EQ(dot->opcode(), HloOpcode::kDot) + << dot->ToString() << " " + << fusion->fused_instructions_computation()->ToString(); + + int64 dot_lhs_param_number = dot->operand(0)->parameter_number(); + int64 dot_rhs_param_number = dot->operand(1)->parameter_number(); + int64 addend_param_number = + root->operand(1 - dot_op_index)->parameter_number(); + + Shape target_shape = fusion->shape(); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array = GetIrArrayFor(fusion); + + llvm_ir::IrArray lhs_array( + GetIrArrayFor(fusion->operand(dot_lhs_param_number))); + llvm_ir::IrArray rhs_array( + GetIrArrayFor(fusion->operand(dot_rhs_param_number))); + llvm_ir::IrArray addend_array( + GetIrArrayFor(fusion->operand(addend_param_number))); + + TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, + lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), + &ir_builder_, hlo_module_config_)); + return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); } @@ -2302,9 +2252,17 @@ Status IrEmitter::HandleCall(HloInstruction* call) { !parallel_cpu_backend_) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. - TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, - emitted_value_[call], computation, - call_ir_function)); + std::vector call_args = GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, computation->name(), + /*return_value_buffer=*/emitted_value_[call], + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument()); + + HloInstruction* root = computation->root_instruction(); + TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( + call_args, root->shape(), root->outer_dimension_partitions(), + &ir_builder_, call_ir_function, computation->name())); } else { EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, emitted_value_[call], computation->name()); @@ -2407,7 +2365,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), - compute_function_); + compute_function_->function()); ir_builder_.CreateBr(header_bb); ir_builder_.SetInsertPoint(header_bb); @@ -2424,7 +2382,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "body")), - compute_function_); + compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); @@ -2439,7 +2397,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { ir_builder_.CreateBr(header_bb); // Adds the exit block to the function and sets the insert point there. - compute_function_->getBasicBlockList().push_back(exit_bb); + compute_function_->function()->getBasicBlockList().push_back(exit_bb); ir_builder_.SetInsertPoint(exit_bb); return Status::OK(); @@ -2475,14 +2433,13 @@ StatusOr IrEmitter::EmitFastConcatenate( int64 concat_dim = concatenate->dimensions(0); const Layout& output_layout = output_shape.layout(); + auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); auto concat_dim_layout_itr = - std::find(output_layout.minor_to_major().begin(), - output_layout.minor_to_major().end(), concat_dim); + std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim); - std::vector inner_dims(output_layout.minor_to_major().begin(), - concat_dim_layout_itr); + std::vector inner_dims(output_min2maj.begin(), concat_dim_layout_itr); std::vector outer_dims(std::next(concat_dim_layout_itr), - output_layout.minor_to_major().end()); + output_min2maj.end()); llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::Type* i8_type = ir_builder_.getInt8Ty(); @@ -2557,7 +2514,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, const llvm_ir::IrArray& source_array) { unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - unsigned element_alignment = GCD( + unsigned element_alignment = tensorflow::MathUtil::GCD( primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); @@ -2604,6 +2561,65 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { return DefaultAction(concatenate); } +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && + pred->shape().element_type() == PRED) + << "Predicate on a Conditional must be bool; got: " + << ShapeUtil::HumanString(pred->shape()); + + HloComputation* true_computation = conditional->true_computation(); + HloComputation* false_computation = conditional->false_computation(); + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + true_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the true " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << ShapeUtil::HumanString(true_computation->root_instruction()->shape()); + + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + false_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the false " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); + + llvm::Function* true_function = + FindOrDie(emitted_functions_, true_computation); + llvm::Function* false_function = + FindOrDie(emitted_functions_, false_computation); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); + llvm::Value* conditional_result = GetEmittedValueFor(conditional); + + // Generating: + // if (pred) + // cond_result = true_computation(true_operand) + // else + // cond_result = false_computation(false_operand) + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + "boolean_predicate"); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, + conditional_result, IrName(conditional, "_true")); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, + conditional_result, IrName(conditional, "_false")); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2639,7 +2655,6 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { if (prof_counter) { profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); } - ir_builder_.CreateRetVoid(); return Status::OK(); } @@ -2780,43 +2795,16 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, module_); } -std::vector IrEmitter::GetComputeFunctionParams() { - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (num_dynamic_loop_bounds_ > 0) { - compute_function_params.push_back(i64_ptr_type); - } - compute_function_params.push_back(i64_ptr_type); - return compute_function_params; -} - -llvm::Argument* IrEmitter::GetResultArgument() { - return GetArg(compute_function_, 0); -} - -llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; - return GetArg(compute_function_, arg_index); +llvm::Value* IrEmitter::GetProfileCountersArgument() { + return compute_function_->profile_counters_arg(); } llvm::Value* IrEmitter::GetTempBuffersArgument() { - return GetArg(compute_function_, 3); -} - -llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { - CHECK_GT(num_dynamic_loop_bounds_, 0); - CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_.CreateLoad(ir_builder_.CreateGEP( - loop_bounds_arg, ir_builder_.getInt64(offset), AsStringRef(name))); + return compute_function_->temp_buffers_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { - return GetArg(compute_function_, 1); + return compute_function_->exec_run_options_arg(); } llvm::Value* IrEmitter::EmitTempBufferPointer( @@ -2847,10 +2835,14 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( GetTempBuffersArgument(), slice.index(), &ir_builder_); llvm::LoadInst* tempbuf_address_base = ir_builder_.CreateLoad(tempbuf_address_ptr); - if (hlo_module_config_.debug_options() + if (is_top_level_computation_ && + hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // Loading the address of a buffer is invariant of the point at which the - // load is executed in the program because we never reassign buffers. + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2881,42 +2873,6 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits code to allocate an array of parameter address pointers, and store -// each address from 'parameter_addresses'. -// Returns an array of compute function call arguments (including parameter -// address buffer). -std::vector IrEmitter::GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt8PtrTy(), - ir_builder_.getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - &ir_builder_); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( - parameter_addresses[i], ir_builder_.getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder_.getInt64(i)}); - ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); - } - - const auto to_int8_ptr = [this](llvm::Value* ptr) { - return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy()); - }; - std::vector arguments{ - to_int8_ptr(return_value_buffer), - to_int8_ptr(GetExecutableRunOptionsArgument()), - parameter_addresses_buffer, GetTempBuffersArgument()}; - if (auto* profile_counters = GetProfileCountersArgument()) { - arguments.push_back(profile_counters); - } - return arguments; -} - // Emits a core function call based on the following pseudo-code. // // char** parameter_addresses_buffer = @@ -2932,8 +2888,12 @@ void IrEmitter::EmitArrayFunctionCallInto( tensorflow::gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { ir_builder_.CreateCall( - function, GetArrayFunctionCallArguments(parameter_addresses, - return_value_buffer, name)); + function, GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -2953,117 +2913,13 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -// Emits a call to a runtime fork/join function which dispatches parallel -// calls to 'parallel_function' (and joins threads before returning). -Status IrEmitter::EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function) { - HloInstruction* root = computation->root_instruction(); - - // Build ParallelForkJoin function type. - std::vector compute_function_params = GetComputeFunctionParams(); - // Number of parallel compute functions. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Array of partitions. There is an array element for each - // partition x partition_dim x 2 (for dimension start and limit). - compute_function_params.push_back( - llvm::Type::getInt64PtrTy(module_->getContext())); - // Number of partitioned most-major dimensions in 'root.shape'. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Function pointer for compute function to be dispatched in parallel. - compute_function_params.push_back( - llvm::Type::getInt8PtrTy(module_->getContext())); - - llvm::FunctionType* fork_join_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, - /*isVarArg=*/false); - - llvm::Function* fork_join_func = - llvm::cast(module_->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); - fork_join_func->setCallingConv(llvm::CallingConv::C); - fork_join_func->setDoesNotThrow(); - - // Add common compute function arguments. - const string name = computation->name(); - std::vector arguments = - GetArrayFunctionCallArguments(parameter_addresses, output_address, name); - - // Create ShapePartitionIterator to generate all partitions of 'root.shape'. - ShapePartitionIterator partition_iterator(root->shape(), - root->outer_dimension_partitions()); - const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); - // Add argument specifying the number of parallel partitions. - arguments.push_back(ir_builder_.getInt32(num_partitions)); - - // The number of partitioned most-major dimensions in 'root.shape'. - const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); - // A dimension partition consists of two elements: [start_index, limit_index). - const int32 dim_partition_size = 2; - // Calculate array partition stride. - const int32 array_partition_stride = - num_partitioned_dims * dim_partition_size; - // Calculate the total number of elements in the partition array. - const int32 partition_array_size = - dim_partition_size * num_partitioned_dims * num_partitions; - - // Store dimension partition values as llvm constants in 'partitions'. - // See comments in runtime_fork_join.cc for array layout description. - std::vector partitions(partition_array_size); - for (int32 i = 0; i < num_partitions; ++i) { - std::vector> dim_partitions = - partition_iterator.GetPartition(i); - CHECK_EQ(num_partitioned_dims, dim_partitions.size()); - const int32 partition_index = i * array_partition_stride; - for (int32 j = 0; j < num_partitioned_dims; ++j) { - const std::pair& dim_partition = dim_partitions[j]; - const int32 index = partition_index + j * dim_partition_size; - // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder_.getInt64(dim_partition.first); - partitions[index + 1] = - ir_builder_.getInt64(dim_partition.first + dim_partition.second); - } - } - - // Create global variable out of dimension partitions in 'partitions'. - llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); - llvm::Constant* partitions_array = - llvm::ConstantArray::get(partitions_array_type, partitions); - llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/partitions_array_type, - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/partitions_array, - /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); - - // Add argument specifying parallel dimension partitions. - arguments.push_back(ir_builder_.CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module_->getContext()))); - // Add argument specifying the number of partitioned most-major dimensions. - arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); - // Add argument for parallel compute function pointer. - arguments.push_back( - ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); - // Emit call to parallel fork/join. - ir_builder_.CreateCall(fork_join_func, arguments); - - return Status::OK(); -} - Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { llvm::Value* addr; const Shape& target_shape = op->shape(); if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. - llvm::Argument* retval = GetResultArgument(); + llvm::Argument* retval = compute_function_->result_arg(); if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); @@ -3124,8 +2980,13 @@ Status IrEmitter::EmitTargetElementLoop( } else { if (ShouldEmitParallelLoopFor(*target_op)) { - TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( - target_shape, element_generator, IrName(target_op), &target_array)); + // Emit code to read dynamic loop bounds from compute function argument. + std::vector> dynamic_loop_bounds = + compute_function_->GetDynamicLoopBounds(); + // Emit parallel loop with dynamic loop bounds for most-major dimensions. + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, + &dynamic_loop_bounds, &ir_builder_) + .EmitLoop(IrName(target_op))); } else { TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) @@ -3135,60 +2996,6 @@ Status IrEmitter::EmitTargetElementLoop( return Status::OK(); } -Status IrEmitter::EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array) { - CHECK(!ShapeUtil::IsTuple(target_shape)); - CHECK(!ShapeUtil::IsScalar(target_shape)); - - // Emit code to read dynamic loop bounds from function argument 4. - std::vector dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); - for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { - dynamic_loop_bounds[i] = GetDynamicLoopBound(i); - } - - llvm_ir::ForLoopNest loop_nest(loop_name, &ir_builder_); - const int64 num_dims = target_shape.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); - - // Add loops from outer-most to inner-most dimensions. - for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - const int64 dimension = target_shape.layout().minor_to_major(i); - const int bounds_index = num_dims - 1 - i; - if (bounds_index < num_dynamic_loop_bounds_) { - // Emit dynamic loop bounds for this dimension. Dynamic loop bounds - // are read from ir function dynamic loop bounds argument. - llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; - llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; - - std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); - array_index[dimension] = loop->GetIndVarValue(); - } else { - // Emit static loop bounds for this dimension. - std::unique_ptr loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/target_shape.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); - array_index[dimension] = loop->GetIndVarValue(); - } - } - // Point IR builder at inner loop BB. - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Emit loop body. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - element_generator(array_index)); - target_array->EmitWriteArrayElement(array_index, target_element, - &ir_builder_); - // Point IR builder at outer loop exit BB. - SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); - - return Status::OK(); -} - Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); @@ -3247,36 +3054,26 @@ StatusOr IrEmitter::EmitScalarCall( argument_addrs, name); } -unsigned TargetMachineFeatures::largest_register_size_in_bytes( - llvm::Function* function) { - auto itr = largest_register_size_in_bytes_.find(function); - if (itr != largest_register_size_in_bytes_.end()) { - return itr->second; +llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( + const llvm::Function& function) { + auto it = target_transform_infos_.find(&function); + if (it == target_transform_infos_.end()) { + // Using a dummy function analysis manager is kind of hacky, but LLVM's + // TargetTransformInfoWrapperPass::getTTI does the same thing. + // + // TODO(sanjoy): Fix this within LLVM by directly exposing + // TargetTransformInfo factories from TargetMachine. + llvm::FunctionAnalysisManager DummyFAM; + llvm::TargetTransformInfo target_transform_info = + target_machine_->getTargetIRAnalysis().run(function, DummyFAM); + auto emplace_result = target_transform_infos_.emplace( + &function, std::move(target_transform_info)); + CHECK(emplace_result.second); + it = emplace_result.first; } - int result = largest_register_size_in_bytes_impl(function); - - InsertOrDie(&largest_register_size_in_bytes_, function, result); - DCHECK_EQ(result, largest_register_size_in_bytes_.begin()->second); - return result; + return &it->second; } -unsigned TargetMachineFeatures::largest_register_size_in_bytes_impl( - llvm::Function* function) const { - auto register_info = - target_machine_->getSubtargetImpl(*function)->getRegisterInfo(); - - unsigned largest_register_size = 0; - for (const llvm::TargetRegisterClass* register_class : - register_info->regclasses()) { - if (register_class->isAllocatable()) { - largest_register_size = - std::max(largest_register_size, - register_info->getRegSizeInBits(*register_class)); - } - } - - return largest_register_size / 8; -} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 351c95278c17f536e56d9f085b938a9baea9cde1..2341e3ea72ff312f2ca54b9495aff4065b34cd81 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -18,11 +18,13 @@ limitations under the License. #include #include +#include #include #include #include #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" @@ -30,6 +32,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -52,15 +55,6 @@ namespace cpu { // Wraps an llvm::TargetMachine and parses out some information that feeds into // code LLVM IR generation decisions. -// -// Ideally we'd be able to use llvm::TargetTransformInfo here (since its -// interface is pretty much a perfect fit for our use case), but obtaining an -// instance of llvm::TargetTransformInfo outside an LLVM pass pipeline without -// super-ugly hacks is difficult. -// -// TODO(b/66049221): See if the LLVM community will be receptive to exposing an -// API that lets us directly create and use llvm::TargetTransformInfo instances -// outside of a pass manager. class TargetMachineFeatures { public: TargetMachineFeatures(llvm::TargetMachine* target_machine) @@ -75,20 +69,21 @@ class TargetMachineFeatures { return 128; } - // Return the size of the largest register size in bytes. We need to pass in + // Return the size of the largest vector size in bytes. We need to pass in // "function" since llvm functions can contain annotations for specializing // them to specific micro-architectures (though currently XLA does not use // this functionality). - // - // Ideally we should have been able to use - // llvm::TargetTransformInfo::getRegisterBitWidth(true) here. - unsigned largest_register_size_in_bytes(llvm::Function* function); + int vector_register_byte_size(const llvm::Function& function) { + llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); + return tti->getRegisterBitWidth(/*Vector=*/true) / 8; + } private: - unsigned largest_register_size_in_bytes_impl(llvm::Function* function) const; + llvm::TargetTransformInfo* GetTargetTransformInfoFor( + const llvm::Function& function); - tensorflow::gtl::FlatMap - largest_register_size_in_bytes_; + tensorflow::gtl::FlatMap + target_transform_infos_; llvm::TargetMachine* target_machine_; }; @@ -189,6 +184,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConditional(HloInstruction* conditional) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -233,16 +229,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); - // Returns an array of compute function parameter types. - std::vector GetComputeFunctionParams(); - - // Get the llvm::Value* that represents the "retval" argument of the - // computation function being emitted by this emitter. - llvm::Argument* GetResultArgument(); - // Get the llvm::Value* that represents the "prof_counters" argument of the // computation function being emitted by this emitter. - llvm::Argument* GetProfileCountersArgument(); + llvm::Value* GetProfileCountersArgument(); // Get the xla::ExecutableRunOptions that represents the "run_options" // argument of the computation function being emitted by this emitter. @@ -252,11 +241,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emit ir to read and return the ir value for the dynamic loop bound at - // 'offset' from the "dynamic_loop_bounds" argument of the computation - // function being emitted by this emitter. - llvm::Value* GetDynamicLoopBound(const int64 offset); - // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. @@ -310,18 +294,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name); - // Returns an array of compute function call arguments. - std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Emits a call to a runtime fork/join function which dispatches parallel - // calls to 'parallel_function' (and joins threads before returning). - Status EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function); - // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -346,15 +318,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, tensorflow::StringPiece desc, const llvm_ir::ElementGenerator& element_generator); - // Emit IR to perform a computation for every element in a partition/slice of - // 'target_shape'. The loop bounds for the outer-dimension partitions are - // passed into the compute function as a runtime argument (accessible from - // GetDynamicLoopBound). - Status EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array); - // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -476,8 +439,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory - // management rules, their memory is owned by the module. - llvm::Function* compute_function_; + // management rules, their memory is owned by the module (Note that IrFunction + // creates the encapsulated llvm::Function s.t. it is added to the llvm + // module's function list). + std::unique_ptr compute_function_; llvm::IRBuilder<> ir_builder_; // Maps HLOs to their index into the profile counter array. @@ -490,7 +455,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; // The number of root instruction outer dimensions used in parallel loop - // emission (EmitParallelTargetElementLoop). + // emission (ParallelLoopEmitter). int64 num_dynamic_loop_bounds_ = 0; // Returns whether the given instruction should be emitted as a parallel loop. @@ -510,7 +475,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { use_rdtscp_(false), prof_counters_(nullptr) {} ProfilingState(bool is_top_level_computation, bool use_rdtscp, - llvm::Argument* prof_counters) + llvm::Value* prof_counters) : is_top_level_computation_(is_top_level_computation), use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} @@ -543,7 +508,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool use_rdtscp_; // The argument which corresponds to the profile counter buffer. - llvm::Argument* prof_counters_; + llvm::Value* prof_counters_; // The first read cycle counter in the program. llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca8c290dd1c4959e42026c3917d37f8fc95a1011 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -0,0 +1,333 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +namespace { +using llvm_ir::AsStringRef; +} // namespace + +namespace cpu { + +static std::vector GetComputeFunctionParams( + llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = + llvm::Type::getInt64PtrTy(llvm_module->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds > 0) { + compute_function_params.push_back(i64_ptr_type); + } + compute_function_params.push_back(i64_ptr_type); + return compute_function_params; +} + +IrFunction::IrFunction(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, + int64 num_dynamic_loop_bounds) + : ir_builder_(ir_builder), + llvm_module_(llvm_module), + caller_insert_point_guard_(*ir_builder), + num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { + Initialize(function_name, linkage, optimize_for_size_requested, + enable_fast_math); +} + +IrFunction::~IrFunction() { + // Emit function return value. + ir_builder_->CreateRetVoid(); +} + +DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { + DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_); + for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0); + dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1); + } + return dynamic_loop_bounds; +} + +void IrFunction::Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math) { + // The function signature is: + // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // i64* dynamic_loop_bounds, i64* prof_counters) + // + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: address of an array with pointers to temporary buffers. + // + // Therefore, the generated function's signature (FunctionType) is statically + // determined - parameter unpacking is done in code generated into the + // function, rather than by a prologue dictated by the platform ABI. + // + // /--------------\ + // retval ----------> | return value | + // \--------------/ + // + // /-------------------------------\ + // run_options -----> | xla::ExecutableRunOptions | + // \-------------------------------/ + // + // /---------------------------------------------\ + // params --------> | param 0 | param 1 | ..... | param N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | param 0 | | param 1 | | param N-1 | + // \---------/ \---------/ \-----------/ + // + // /---------------------------------------------\ + // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | temp 0 | | temp 1 | | temp N-1 | + // \---------/ \---------/ \-----------/ + // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // + // /---------------------------------------------\ + // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | + // \---------------------------------------------/ + + // Even though the type of params and temps is void** in the host's view, in + // LLVM IR this is represented by i8*, similarly to void*. It's up to the code + // to use GEPs to unravel the indirection layers. + llvm::FunctionType* function_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), + /*Params=*/ + GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_), + /*isVarArg=*/false); + + // Functions with local linkage get an inlining bonus. Because we know + // a-priori that embedded functions (non-entry functions) will not have its + // name resolved, give it local linkage. + function_ = + llvm_ir::CreateFunction(function_type, linkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size_requested, + function_name, llvm_module_); + + // Set meaningful names for the function's arguments: useful for debugging. + llvm::Function::arg_iterator arg_iter = function_->arg_begin(); + arg_iter->setName("retval"); + result_arg_ = &*arg_iter; + (++arg_iter)->setName("run_options"); + exec_run_options_arg_ = &*arg_iter; + (++arg_iter)->setName("params"); + parameters_arg_ = &*arg_iter; + (++arg_iter)->setName("temps"); + temp_buffers_arg_ = &*arg_iter; + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + dynamic_loop_bounds_arg_ = &*arg_iter; + } + (++arg_iter)->setName("prof_counters"); + profile_counters_arg_ = &*arg_iter; + + // We know a-priori that the function arguments are guaranteed to point to + // disjoint objects. + llvm::Argument* retval = result_arg(); + for (llvm::Argument& argument : function_->args()) { + // However, the return buffer aliases the temporaries and thus cannot be + // marked noalias. + if (&argument == retval) { + continue; + } + function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); + } + + ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + /*Context=*/llvm_module_->getContext(), + /*Name=*/"entry", + /*Parent=*/function_)); +} + +llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_->CreateLoad( + ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + ir_builder_->getInt64(offset), AsStringRef(name))); +} + +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + llvm::Value* parameter_addresses_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder->getInt8PtrTy(), + ir_builder->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), + ir_builder); + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( + parameter_addresses[i], ir_builder->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); + llvm::Value* slot_in_param_adresses = ir_builder->CreateInBoundsGEP( + parameter_addresses_buffer, {ir_builder->getInt64(i)}); + ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_adresses); + } + + const auto to_int8_ptr = [=](llvm::Value* ptr) { + return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + }; + std::vector arguments{ + to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), + parameter_addresses_buffer, temp_buffers_arg}; + if (profile_counters_arg != nullptr) { + arguments.push_back(profile_counters_arg); + } + return arguments; +} + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + + // Build ParallelForkJoin function type. + std::vector compute_function_params = + GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module->getContext())); + // Number of partitioned most-major dimensions in 'shape'. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast(module->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + std::vector fork_join_arguments(arguments); + + // Create ShapePartitionIterator to generate all partitions of 'shape'. + ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'shape'. + const int32 num_partitioned_dims = dimension_partition_counts.size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. + const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder->getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + fork_join_arguments.push_back(ir_builder->CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + fork_join_arguments.push_back( + ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder->CreateCall(fork_join_func, fork_join_arguments); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h new file mode 100644 index 0000000000000000000000000000000000000000..1fd2da4dce23982ed030f3aa8ec604182d0ebab8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ + +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace cpu { + +// IrFunction creates and encapsulates an llvm::Function, exposing methods to +// emitters for function and function argument access. +// The llvm::Function is created with the standard function signature +// used in the XLA CPU backend (see ir_function.cc for argument details). +// In addtion IrFunction saves the callers IR insert point during contruction, +// and restores it after desctruction. +// +// Example usage: +// +// // Create and initialize new IrFunction. +// std::unique_ptr compute_function(new IrFunction(...)); +// // Emit IR for function body using IrFunction helper methods. +// ... +// // Store reference to llvm::Function for future invocation. +// ir_functions.push_back(compute_function.function()); +// // Delete IrFunction (finalizes IR function and restores caller insertion +// // point). +// compute_function.reset(); +// + +class IrFunction { + public: + IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); + ~IrFunction(); + + // Emit ir to read and return the set of ir values representing the dynamic + // loop bounds argument of this function. + // Each element in returned vector is a pair of ir values representing + // the loop bounds for a specific dimension, where the first element of the + // pair is the dimension start index, and the second element of the pair + // is the dimension limit. + // EX: [dimension_i_index_start_ir_value, dimension_i_index_limit_ir_value] + // + DynamicLoopBounds GetDynamicLoopBounds(); + + // Returns the encapculated llvm::Function. + llvm::Function* function() { return function_; } + + // Get the llvm::Value* that represents this functions "retval" argument. + llvm::Argument* result_arg() { return result_arg_; } + + // Get the xla::ExecutableRunOptions that represents this functions + // "run_options" argument. + llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; } + + // Get the llvm::Value* that represents this functions parameters argument. + llvm::Value* parameters_arg() { return parameters_arg_; } + + // Get the llvm::Value* that represents this functions "temps" argument. + llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + + // Get the llvm::Value* that represents this functions "prof_counters" + // argument. + llvm::Value* profile_counters_arg() { return profile_counters_arg_; } + + private: + // Initialize an llvm::Function with standard signature based on arguments. + void Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + bool optimize_for_size_requested, bool enable_fast_math); + + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of this function. + llvm::Value* GetDynamicLoopBound(int64 offset); + + llvm::IRBuilder<>* ir_builder_; + llvm::Module* llvm_module_; + llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; + + int64 num_dynamic_loop_bounds_ = 0; + // Encapsulated llvm::Function. + llvm::Function* function_; + // Function argument IR values. + llvm::Argument* result_arg_; + llvm::Value* exec_run_options_arg_; + llvm::Value* parameters_arg_; + llvm::Value* temp_buffers_arg_; + llvm::Value* dynamic_loop_bounds_arg_ = nullptr; + llvm::Value* profile_counters_arg_; +}; + +// Returns an array of compute function call argument ir values. +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name); + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index aff61296ced47a911ded207f611747564b5ac7eb..d1b88b27f068962fb86477fcad3e4390b1636c2b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -59,19 +59,20 @@ ParallelCpuExecutable::ParallelCpuExecutable( std::unique_ptr assignment, std::unique_ptr hlo_module, std::unique_ptr> function_names, - std::unordered_map hlo_to_profile_idx, std::unordered_map> - aligned_constants) - : Executable(std::move(hlo_module)), + aligned_constants, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), jit_(std::move(jit)), assignment_(std::move(assignment)), function_names_(std::move(function_names)), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), aligned_constants_(std::move(aligned_constants)) {} // Type of the computation function we expect in the JIT. using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - int64*, uint64*); + int64*, int64*); // Given a pointer to an output buffer (following the CPU JIT calling // conventions), mark addresses that are "live". The initial pointer itself is @@ -106,7 +107,7 @@ class Executor { const ServiceExecutableRunOptions* run_options, std::list* pending, HloInstructionMap* results, void** temps_array, - uint64* profile_counters_array, const BufferAssignment* assignment) + int64* profile_counters_array, const BufferAssignment* assignment) : functions_(functions), run_options_(run_options), pending_(pending), @@ -147,7 +148,7 @@ class Executor { std::list* pending_; HloInstructionMap* results_; void** temps_array_; - uint64* profile_counters_array_; + int64* profile_counters_array_; tensorflow::thread::ThreadPool* thread_pool_; const BufferAssignment* assignment_; @@ -375,23 +376,12 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { - std::vector argument_buffers(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - argument_buffers[i] = arguments[i]->buffer(/*index=*/{}); - } - return ExecuteComputeFunctions(run_options, argument_buffers, buffers, - hlo_execution_profile); -} - -Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { // Allocate profiling counters for each hlo instruction that we would like to - // profile. Allocate an additional profile counter for the entire - // computation. - std::vector profile_counters(hlo_to_profile_idx_.size() + 1); + // profile. + std::vector* profile_counters = nullptr; + if (hlo_execution_profile) { + profile_counters = hlo_execution_profile->mutable_profile_counters(); + } std::vector buffer_pointers; buffer_pointers.reserve(buffers.size()); @@ -425,8 +415,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( // just copy the existing buffer into the map containing instruction // results.. if (instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&results, instruction, - arguments[instruction->parameter_number()].opaque()); + InsertOrDie( + &results, instruction, + arguments[instruction->parameter_number()]->root_buffer().opaque()); } else if (instruction->opcode() == HloOpcode::kConstant) { unsigned char* aligned_data = FindOrDie(aligned_constants_, instruction).get(); @@ -441,9 +432,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( // For example, if we expect a library conv/matmul call to run at max // concurrency, we should not dispatch runnable instructions until the // library call is finished (to avoid expensive cache invalidation). - Executor executor(functions, run_options, &pending, &results, - buffer_pointers.data(), profile_counters.data(), - assignment_.get()); + Executor executor( + functions, run_options, &pending, &results, buffer_pointers.data(), + profile_counters ? profile_counters->data() : nullptr, assignment_.get()); TF_RETURN_IF_ERROR(executor.Run()); @@ -453,86 +444,11 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::mutex_lock lock(mutex_); double nanoseconds = (end_micros - start_micros) * 1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - // The last profile counter is used for the computation as a whole. - execution_profile_.set_compute_cycle_count(profile_counters.back()); - } - if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed(entry_computation, - profile_counters.back()); - - for (auto hlo_prof_idx : hlo_to_profile_idx_) { - const HloInstruction* hlo = hlo_prof_idx.first; - uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken); - } } return Status::OK(); } -StatusOr -ParallelCpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); - if (!arguments.empty()) { - VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); - } - - // Allocate the temporary buffers required for the computation. - se::StreamExecutor* stream_executor = stream->parent(); - int device_ordinal = stream_executor->device_ordinal(); - int64 buffer_count = assignment_->Allocations().size(); - VLOG(3) << "temp buffer count: " << buffer_count; - - std::vector device_allocations( - assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers(memory_allocator, - stream->parent()->device_ordinal(), - &device_allocations)); - - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - const BufferAllocation::Index result_index = result_slice.index(); - VLOG(3) << "result index: " << result_index; - - TF_RETURN_IF_ERROR(ExecuteComputeFunctions( - run_options, arguments, device_allocations, hlo_execution_profile)); - - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. - std::unordered_set marked_addresses; - MarkLiveAddressesInOutput(device_allocations[result_index].opaque(), - result_shape(), &marked_addresses); - - VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" - << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); - - // Computation is done - deallocate temp buffers. Keep those marked - // live because they are referenced by the output of the computation - // and are needed by the service. They will be deallocated by the - // service. - for (size_t i = 0; i < device_allocations.size(); ++i) { - auto alloc = device_allocations[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && - alloc.opaque() != nullptr) { - VLOG(3) << "ParallelCpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate(device_ordinal, &alloc)); - } - } - - return device_allocations[result_index]; -} - StatusOr> ParallelCpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -545,9 +461,9 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - auto result_buffer = - MakeUnique(result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal()); + auto result_buffer = MakeUnique( + /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); @@ -558,37 +474,30 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( // Copy DeviceMemoryBase values which into the respective location in // ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = - this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); + TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + + // The points to set is unambiguous so the set should be a singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer such as a + // tuple element. The source instruction should have a non-parameter + // buffer assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *device_memory = buffer; + buffers_in_result[buffer_index] = true; + return Status::OK(); + })); // Free all buffers not in the result. for (size_t i = 0; i < buffers.size(); ++i) { @@ -604,10 +513,10 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( return std::move(result_buffer); } -StatusOr +StatusOr> ParallelCpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on CPU."); @@ -618,10 +527,5 @@ const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr ParallelCpuExecutable::CreateCostAnalysis() - const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index db16aaf48b0ef2aaa727c1bd0106bc51d1a65095..90ac94ef9288b2e860cb30c47ed44a7b96e4825d 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -52,27 +52,21 @@ class ParallelCpuExecutable : public Executable { std::unique_ptr assignment, std::unique_ptr hlo_module, std::unique_ptr> function_names, - std::unordered_map hlo_to_profile_idx, std::unordered_map> - aligned_constants); + aligned_constants, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); ~ParallelCpuExecutable() override {} - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -95,8 +89,6 @@ class ParallelCpuExecutable : public Executable { "Equality test on CPU parallel executable is not implemented."); } - std::unique_ptr CreateCostAnalysis() const override; - private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer @@ -109,13 +101,6 @@ class ParallelCpuExecutable : public Executable { // Calls the generated functions in 'function_names_', performing the // computation with the given arguments using the supplied buffers. - Status ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunctions( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -143,9 +128,6 @@ class ParallelCpuExecutable : public Executable { // Map containing the JITted function names for each HLO instruction. const std::unique_ptr> function_names_; - // Maps HLOs to their index into the profile counter array. - const std::unordered_map hlo_to_profile_idx_; - // Map from HLO Constant instructions to a pointer to their literal data. // The data stored in the protocol buffer might be insufficiently aligned, // we create a sufficiently aligned copy and store it in this map. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e439cde11cf74272101b80c867a308e51ab26a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace cpu { + +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_array, ir_builder), + dynamic_loop_bounds_(dynamic_loop_bounds) {} + +llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) { + CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsScalar(shape_)); + + llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); + const int64 num_dims = shape_.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { + const int64 dimension = LayoutUtil::Minor(shape_.layout(), i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < dynamic_loop_bounds_->size()) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = (*dynamic_loop_bounds_)[bounds_index].first; + llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; + + std::unique_ptr loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. + std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), + ir_builder_); + + // Set exit_bb_ to the exit block of the loop nest. + exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); + CHECK(exit_bb_ != nullptr); + + return array_index; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..9335d2818e99eb3588537d80dabddda08c1c020e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" + +namespace xla { +namespace cpu { + +// ParallelLoopEmitter emits a loop nest for the target array shape. +// The outer loop bounds of the loop nest are passed as ir values at runtime +// (specified in 'dynamic_loop_bounds'), and the inner loop bounds are static. +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// Code emitted by ParallelLoopEmitter will be called in a multi-threaded +// context where each thread will be assigned a different set of outer dimension +// partitions, and where all threads will collectively iterate over the +// entire target array shape. +// +// Outer dimension partitions can be generated using the ShapePartitionAssigner +// and ShapePartitionIterator utility classes from shape_partition.cc. +// +class ParallelLoopEmitter : public llvm_ir::LoopEmitter { + public: + // Constructs a ParallelLoopEmitter which uses 'target_element_generator' to + // generate elements, 'dynamic_loop_bounds' to set the loop bounds of the + // most-major dimensions, and 'target_array.' shape to set the static loop + // bounds for the most-minor dimensions. + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, + llvm::IRBuilder<>* ir_builder); + + ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; + ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; + ~ParallelLoopEmitter() override = default; + + llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) override; + + private: + const DynamicLoopBounds* dynamic_loop_bounds_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index cda2783307925b77ac6d8cfe679c5b325db2befc..c942cd6bf12c58873d5195f7454249763e639f91 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -102,9 +102,21 @@ llvm::StringRef GetHostCpuName() { CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { CompilerFunctor::VectorIntrinsics intrinsics; - intrinsics.sse_intrinsics = (&__xla_cpu_runtime_ExpV4F32SSE != nullptr); - intrinsics.avx_intrinsics = (&__xla_cpu_runtime_ExpV8F32AVX != nullptr); - intrinsics.neon_intrinsics = (&__xla_cpu_runtime_ExpV4F32NEON != nullptr); +#ifdef __SSE4_1__ + intrinsics.sse_intrinsics = true; +#else + intrinsics.sse_intrinsics = false; +#endif +#ifdef __AVX__ + intrinsics.avx_intrinsics = true; +#else + intrinsics.avx_intrinsics = false; +#endif +#ifdef __ARM_NEON__ + intrinsics.neon_intrinsics = true; +#else + intrinsics.neon_intrinsics = false; +#endif return intrinsics; } @@ -201,12 +213,18 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); +#ifdef __ARM_NEON__ REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON); - REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); - REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); +#endif +#ifdef __SSE4_1__ + REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); +#endif +#ifdef __AVX__ + REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); +#endif REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); @@ -275,7 +293,11 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long)); REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); REGISTER_LIBM_SYMBOL(sin, double (*)(double)); +#ifdef __APPLE__ + REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); +#else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); +#endif REGISTER_LIBM_SYMBOL(sinh, double (*)(double)); REGISTER_LIBM_SYMBOL(sqrt, double (*)(double)); REGISTER_LIBM_SYMBOL(tan, double (*)(double)); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index bc73839a88d8d3f231b4f3e924706b1a207562c6..0d54e325e618a7b1aae38407958ddf7b41ef1cda 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -86,6 +86,9 @@ class DfsHloVisitorBase { virtual Status HandleConvert(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleBitcastConvert(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleCopy(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -208,6 +211,7 @@ class DfsHloVisitorBase { virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; virtual Status HandleWhile(HloInstructionPtr hlo) = 0; + virtual Status HandleConditional(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; @@ -243,6 +247,10 @@ class DfsHloVisitorBase { // affecting correctness. void ReserveVisitStates(int num) { visit_state_.Reserve(num); } + // Useful when we want to visit the same computation more than once with the + // same visitor. + void ResetVisitStates() { visit_state_.Reset(); } + void SetVisitState(int id, VisitState state) { visit_state_.SetState(id, state); } @@ -322,6 +330,7 @@ class DfsHloVisitorBase { *w = (*w & ~mask) | (static_cast(state) << shift); DCHECK_EQ(GetState(id), state); } + void Reset() { states_.clear(); } private: static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 5415bab5b358edb3f64467f457e5273d117429b8..133aa2509405738de8388708b0c61a82023e2738 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -167,6 +167,9 @@ class DfsHloVisitorWithDefaultBase Status HandleWhile(HloInstructionPtr xla_while) override { return DefaultAction(xla_while); } + Status HandleConditional(HloInstructionPtr conditional) override { + return DefaultAction(conditional); + } Status HandleRecv(HloInstructionPtr recv) override { return DefaultAction(recv); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc new file mode 100644 index 0000000000000000000000000000000000000000..12faed69677cd99c6ed82c8d13dad3138d9461b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -0,0 +1,185 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dot_decomposer.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// TODO(b/69062148) Remove this code when all backends support BatchDot +// natively. +Status DecomposeBatchDot(HloInstruction* dot) { + auto computation = dot->parent(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const Shape& dot_shape = dot->shape(); + + // ShapeInference should guarantee that lhs/rhs batch dimensions match. + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + const int64 num_batch_dims = dnums.lhs_batch_dimensions_size(); + // Calculate total batch size (note that ShapeInference requires that + // the batch dimensions are most-major). + int64 batch_size = 1; + for (int i = 0; i < num_batch_dims; ++i) { + CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)), + rhs_shape.dimensions(dnums.rhs_batch_dimensions(i))); + batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)); + } + + // Set lhs/rhs_transpose. + CHECK_EQ(1, dnums.lhs_contracting_dimensions_size()); + const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0); + const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0; + + CHECK_EQ(1, dnums.rhs_contracting_dimensions_size()); + const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0); + const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1; + + // Compute R3 and R3 shapes for lhs. + PrimitiveType lhs_type = lhs_shape.element_type(); + const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0); + const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1); + Shape lhs_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r2 = + ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols}); + + // Compute R3 and R3 shapes for rhs. + PrimitiveType rhs_type = rhs_shape.element_type(); + const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0); + const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1); + Shape rhs_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r2 = + ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols}); + + // Compute R3 and R3 shapes for dot output. + PrimitiveType dot_type = dot_shape.element_type(); + const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0); + const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1); + Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols}); + Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols}); + Shape concat_shape_r3 = + ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols}); + + // Reshape lhs/rhs into R3. + auto lhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_shape_r3, lhs)); + auto rhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_shape_r3, rhs)); + + // Loop through batch size, slicing out required lhs/rhs to compute each Dot. + std::vector output_slices(batch_size); + for (int64 i = 0; i < batch_size; ++i) { + // Slice R3 shape from 'lhs' and reshape to R2. + auto lhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0}, + {i + 1, lhs_rows, lhs_cols}, {1, 1, 1})); + auto lhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3)); + + // Slice R3 shape from 'rhs' and reshape to R2. + auto rhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0}, + {i + 1, rhs_rows, rhs_cols}, {1, 1, 1})); + auto rhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3)); + + // Transpose lhs/rhs (if needed). + if (lhs_transpose) { + Shape lhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows}); + lhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0})); + } + if (rhs_transpose) { + Shape rhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows}); + rhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0})); + } + + // Compute Dot of lhs/rhs R2 slices. + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( + dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + + // Reshape Dot to R3 so we can concat along batch dimension. + auto dot_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape_r3, dot_r2)); + + output_slices[i] = dot_r3; + } + + // Concatenate slices from 'output_slices' along batch dimension. + auto concat = computation->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0)); + // Reshape output 'new_dot' to original dimensions. + auto new_dot = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape, concat)); + + // Replace all uses of 'dot' in 'computation' with 'new_dot'. + return computation->ReplaceInstruction(dot, new_dot); +} + +} // namespace + +StatusOr DotDecomposer::Run(HloModule* module) { + XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); + // Gather all batch Dot operations. + std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + bool changed = false; + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h new file mode 100644 index 0000000000000000000000000000000000000000..5ff0ab34eac0cd0fbc264b408c57653c944402a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// DotDecomposer is a pass which decomposes batch Dot operations into a +// sequence of smaller (R2) Dot operations. +class DotDecomposer : public HloPassInterface { + public: + // Decomposes batch Dot operations when 'decompose_batch_dot' is true. + DotDecomposer(bool decompose_batch_dot = true) + : decompose_batch_dot_(decompose_batch_dot) {} + ~DotDecomposer() = default; + tensorflow::StringPiece name() const override { return "dot_decomposer"; } + + // Run DotDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + StatusOr Run(HloModule* module) override; + + private: + bool decompose_batch_dot_; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 606868034ac54c6fe0062d20e7a185c0a9ccd841..37929294327d2a57bb0ab1c48e90b6843cba6ae4 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -50,11 +50,161 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrCat; +namespace { + +llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, + int64 mantissa_bits, + llvm::IRBuilder<>* ir_builder) { + // Integer and float types for casting and constant generation. + llvm::Type* float_type = x->getType(); + llvm::IntegerType* int_type = ir_builder->getInt32Ty(); + + // Cast the input value to an integer for bitwise manipulation. + llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type); + + if (mantissa_bits < 23) { + // Last remaining mantissa bit. + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. + const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; + llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr( + ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + (23 - mantissa_bits)); + llvm::Value* x_rounding_bias = ir_builder->CreateAdd( + x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + } + + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + llvm::Value* x_exponent = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + llvm::Value* x_overflows = ir_builder->CreateICmpUGT( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); + llvm::Value* x_underflows = ir_builder->CreateICmpULE( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); + + // Compute appropriately-signed values of zero and infinity. + llvm::Value* x_signed_zero = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); + llvm::Value* x_signed_inf = ir_builder->CreateOr( + x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int); + } + + // Cast the result back to a floating-point type. + llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type); + + // Correct result for NaN inputs. + // + // The exponent handling will "normalize" NaN values to infinities, which is + // undesirable (except in the case with no mantissa bits, in which case it + // is mandatory). This logic also handles cases where mantissa-rounding + // causes a NaN's mantissa to overflow into the exponent bits, which would + // otherwise create an erroneous zero value. + // + // If the fast-math flags are set to assume no NaNs, the comparison is likely + // to be optimized away, so there's no point in even emitting it. + if (!ir_builder->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x); + + if (mantissa_bits > 0) { + result = ir_builder->CreateSelect(x_is_nan, x, result); + } else { + result = ir_builder->CreateSelect( + x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); + } + } + return result; +} + +llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, + llvm::IRBuilder<>* ir_builder) { + auto reduced_precision = EmitReducePrecisionFloat( + f32_value, + /*exponent_bits=*/primitive_util::kBFloat16ExponentBits, + /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder); + auto as_int32 = + ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateLShr(as_int32, 16); + auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty()); + return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty()); +} + +llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, + llvm::IRBuilder<>* ir_builder) { + auto as_int16 = + ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty()); + auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateShl(as_int32, 16); + return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy()); +} + +llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, + PrimitiveType from_type, + PrimitiveType to_type, llvm::Module* module, + llvm::IRBuilder<>* ir_builder) { + if (primitive_util::IsSignedIntegralType(from_type)) { + return ir_builder->CreateSIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } else { + CHECK(primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED); + return ir_builder->CreateUIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } +} + +} // namespace + StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { if (op->opcode() == HloOpcode::kCopy) { return operand_value; - } else if (operand_value->getType()->isIntegerTy()) { + } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + op->operand(0)->shape().element_type() == PRED) { return EmitIntegerUnaryOp(op, operand_value); } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { return EmitComplexUnaryOp(op, operand_value); @@ -79,15 +229,14 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::IsSignedIntegralType(to_type)); } if (primitive_util::IsFloatingPointType(to_type)) { - if (primitive_util::IsSignedIntegralType(from_type)) { - return ir_builder_->CreateSIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); - } - if (primitive_util::IsUnsignedIntegralType(from_type) || - from_type == PRED) { - return ir_builder_->CreateUIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + if (to_type == BF16) { + return EmitF32ToBF16( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + ir_builder_), + ir_builder_); } + return EmitIntegralToFloating(operand_value, from_type, to_type, + module_, ir_builder_); } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( @@ -110,6 +259,26 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( PrimitiveType_Name(from_type).c_str(), PrimitiveType_Name(to_type).c_str()); } + case HloOpcode::kBitcastConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + PrimitiveType to_type = op->shape().element_type(); + CHECK(primitive_util::IsIntegralType(from_type)); + if (from_type == to_type) { + return operand_value; + } + if (primitive_util::BitWidth(from_type) == + primitive_util::BitWidth(to_type)) { + return ir_builder_->CreateBitCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + } + return InvalidArgument( + "bitcast conversion from primitive type %s to %s with unequal " + "bit-widths (%u versus %u) ", + PrimitiveType_Name(from_type).c_str(), + PrimitiveType_Name(to_type).c_str(), + primitive_util::BitWidth(from_type), + primitive_util::BitWidth(to_type)); + } case HloOpcode::kAbs: { bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); @@ -187,6 +356,17 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } + if (from_type == BF16) { + TF_RET_CHECK(to_type != BF16); + operand_value = EmitBF16ToF32(operand_value, ir_builder_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F32 && to_type == BF16) { + return EmitF32ToBF16(operand_value, ir_builder_); + } if (primitive_util::IsFloatingPointType(to_type)) { return ir_builder_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); @@ -203,22 +383,34 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType_Name(from_type).c_str(), PrimitiveType_Name(to_type).c_str()); } + case HloOpcode::kBitcastConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + PrimitiveType to_type = op->shape().element_type(); + CHECK(primitive_util::IsFloatingPointType(from_type)); + if (from_type == to_type) { + return operand_value; + } + if (primitive_util::BitWidth(from_type) == + primitive_util::BitWidth(to_type)) { + return ir_builder_->CreateBitCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + } + return InvalidArgument( + "bitcast conversion from primitive type %s to %s with unequal " + "bit-widths (%u versus %u) ", + PrimitiveType_Name(from_type).c_str(), + PrimitiveType_Name(to_type).c_str(), + primitive_util::BitWidth(from_type), + primitive_util::BitWidth(to_type)); + } case HloOpcode::kExp: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitExp(op->shape().element_type(), operand_value); case HloOpcode::kLog: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitLog(op->shape().element_type(), operand_value); case HloOpcode::kCos: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitSin(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, @@ -269,9 +461,25 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( StatusOr ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + PrimitiveType component_type = + primitive_util::IsComplexType(input_type) + ? primitive_util::ComplexComponentType(input_type) + : input_type; switch (op->opcode()) { - // TODO(b/65209142): Angle/Log require atan2. - // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + case HloOpcode::kLog: { + // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -293,15 +501,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto exp_a = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::exp, {EmitExtractReal(operand_value)}, - {EmitExtractReal(operand_value)->getType()}, ir_builder_); - auto cos_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::cos, {EmitExtractImag(operand_value)}, - {EmitExtractImag(operand_value)->getType()}, ir_builder_); - auto sin_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sin, {EmitExtractImag(operand_value)}, - {EmitExtractImag(operand_value)->getType()}, ir_builder_); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } @@ -316,16 +521,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); auto type = a->getType(); - auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); auto half_exp_b = ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); - auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, - {type}, ir_builder_); - auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); + TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( op, ir_builder_->CreateFMul( @@ -346,16 +548,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); auto type = a->getType(); - auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); auto half_exp_b = ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); - auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, - {type}, ir_builder_); - auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); + TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( op, ir_builder_->CreateFMul( @@ -363,6 +562,58 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); } + case HloOpcode::kTanh: { + /* + tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) + e^(a+bi) = e^a*(cos(b)+sin(b)i) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) + cos(b)=cos(-b), sin(-b)=-sin(b) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) + =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / + (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / + (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) + This is a complex division, so we can multiply by denom_conj/denom_conj + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * + (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + */ + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); + TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); + TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); + auto exp_neg_a = ir_builder_->CreateFDiv( + llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(exp_a, exp_a), + ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); + auto real_num = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_a_plus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + auto exp_a_minus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = ir_builder_->CreateFMul( + cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, + exp_a_minus_exp_neg_a_sq)); + auto denom = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), + ir_builder_->CreateFDiv(imag_num, denom)); + } case HloOpcode::kAbs: { auto sum_sq = ir_builder_->CreateFAdd( ir_builder_->CreateFMul(EmitExtractReal(operand_value), @@ -409,7 +660,8 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { PrimitiveType operand_type = op->operand(0)->shape().element_type(); - if (lhs_value->getType()->isIntegerTy()) { + if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + operand_type == PRED) { return EmitIntegerBinaryOp( op, lhs_value, rhs_value, primitive_util::IsSignedIntegralType(operand_type)); @@ -424,7 +676,6 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { switch (op->opcode()) { - // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: @@ -468,10 +719,9 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( case HloOpcode::kMinimum: return EmitFloatMin(lhs_value, rhs_value); case HloOpcode::kPower: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, - {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder_); - + return EmitPow(op->shape().element_type(), lhs_value, rhs_value); + case HloOpcode::kAtan2: + return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -567,9 +817,40 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( EmitExtractImag(lhs_value), EmitExtractImag(rhs_value), ir_builder_)); - // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic - // case HloOpcode::kPower: - // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2) + case HloOpcode::kPower: { + // (a+bi)^(c+di) = + // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), + // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) + PrimitiveType component_type = + primitive_util::ComplexComponentType(op->shape().element_type()); + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); + auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto half_c = ir_builder_->CreateFMul(one_half, c); + + TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, + EmitPow(component_type, aa_p_bb, half_c)); + auto neg_d = ir_builder_->CreateFNeg(d); + TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); + auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, + EmitExp(component_type, neg_d_arg_lhs)); + auto coeff = + ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); + auto half_d = ir_builder_->CreateFMul(one_half, d); + auto q = + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), + ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); + TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), + ir_builder_->CreateFMul(coeff, sin_q)); + } default: return Unimplemented("binary complex op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -672,116 +953,51 @@ StatusOr ElementalIrEmitter::EmitErfcInv( return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } -StatusOr ElementalIrEmitter::EmitReducePrecision( - const HloInstruction* hlo, llvm::Value* x) const { - if (hlo->operand(0)->shape().element_type() != F32) { - return Unimplemented("reduce-precision only implemented for F32"); - } - - // Integer and float types for casting and constant generation. - llvm::Type* float_type = x->getType(); - llvm::IntegerType* int_type = ir_builder_->getInt32Ty(); - - // Cast the input value to an integer for bitwise manipulation. - llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type); - - if (hlo->mantissa_bits() < 23) { - // Last remaining mantissa bit. - const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits()); - - // Compute rounding bias for round-to-nearest with ties to even. This is - // equal to a base value of 0111... plus one bit if the last remaining - // mantissa bit is 1. - const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; - llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr( - ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), - (23 - hlo->mantissa_bits())); - llvm::Value* x_rounding_bias = ir_builder_->CreateAdd( - x_last_mantissa_bit, - llvm::ConstantInt::get(int_type, base_rounding_bias)); - - // Add rounding bias, and mask out truncated bits. Note that the case - // where adding the rounding bias overflows into the exponent bits is - // correct; the non-masked mantissa bits will all be zero, and the - // exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias); - x_as_int = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); - } - - if (hlo->exponent_bits() < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- - // significant bit -- is equal to 1.0f for all exponent sizes. Adding - // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- - // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' - // exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n is - // (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (hlo->exponent_bits() - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; +StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, + {value->getType()}, ir_builder_); +} - // Do we overflow or underflow? - llvm::Value* x_exponent = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - llvm::Value* x_overflows = ir_builder_->CreateICmpUGT( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); - llvm::Value* x_underflows = ir_builder_->CreateICmpULE( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); +StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, + {value->getType()}, ir_builder_); +} - // Compute appropriately-signed values of zero and infinity. - llvm::Value* x_signed_zero = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); - llvm::Value* x_signed_inf = ir_builder_->CreateOr( - x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); +StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, + {value->getType()}, ir_builder_); +} - // Force to zero or infinity if overflow or underflow. (Note that this - // truncates all denormal values to zero, rather than rounding them.) - x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int); - x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int); - } +StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, + {value->getType()}, ir_builder_); +} - // Cast the result back to a floating-point type. - llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type); +StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, + {lhs->getType()}, ir_builder_); +} - // Correct result for NaN inputs. - // - // The exponent handling will "normalize" NaN values to infinities, which is - // undesirable (except in the case with no mantissa bits, in which case it - // is mandatory). This logic also handles cases where mantissa-rounding - // causes a NaN's mantissa to overflow into the exponent bits, which would - // otherwise create an erroneous zero value. - // - // If the fast-math flags are set to assume no NaNs, the comparison is likely - // to be optimized away, so there's no point in even emitting it. - if (!ir_builder_->getFastMathFlags().noNaNs()) { - llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x); +StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return Unimplemented("atan2"); +} - if (hlo->mantissa_bits() > 0) { - result = ir_builder_->CreateSelect(x_is_nan, x, result); - } else { - result = ir_builder_->CreateSelect( - x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); - } +StatusOr ElementalIrEmitter::EmitReducePrecision( + const HloInstruction* hlo, llvm::Value* x) const { + if (hlo->operand(0)->shape().element_type() != F32) { + return Unimplemented("reduce-precision only implemented for F32"); } - return result; + return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), + /*mantissa_bits=*/hlo->mantissa_bits(), + ir_builder_); } StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( @@ -865,7 +1081,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If no implicit broadcast is needed for this operand, returns the target // index as the source index. - if (ShapeUtil::Compatible(operand_shape, hlo.shape())) { + if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) { return target_index; } @@ -1073,6 +1289,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: @@ -1081,11 +1298,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kNegate: + case HloOpcode::kNot: case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: - case HloOpcode::kNot: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, @@ -1094,6 +1311,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitUnaryOp(hlo, operand_value); }; case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kAtan2: case HloOpcode::kComplex: case HloOpcode::kDivide: @@ -1106,14 +1324,13 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: + case HloOpcode::kOr: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + case HloOpcode::kSubtract: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { const HloInstruction* lhs = hlo->operand(0); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index cccb498f82936283a215370787907b293827ff2d..1a48eb5fcb960b60d524ea56a43e15269576db76 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -39,7 +39,7 @@ class ElementalIrEmitter { module_(module), hlo_module_config_(hlo_module_config) {} - virtual ~ElementalIrEmitter() {} + virtual ~ElementalIrEmitter() = default; virtual StatusOr EmitUnaryOp(const HloInstruction* op, llvm::Value* operand_value) const; @@ -92,6 +92,26 @@ class ElementalIrEmitter { virtual StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + + virtual StatusOr EmitLog(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitSin(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitCos(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitExp(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 9c96d9eb30b5f9e51b7f5d82391c6b9f366898d6..c50aaec5725021eeaa2fe0c3247f7539327268ae 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -26,23 +26,23 @@ limitations under the License. namespace xla { -StatusOr> +StatusOr>> Executable::ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> + tensorflow::gtl::ArraySlice> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); + std::vector> return_values(run_options.size()); + if (run_options.size() == 1) { - TF_ASSIGN_OR_RETURN(auto result, + TF_ASSIGN_OR_RETURN(return_values[0], ExecuteOnStream(&run_options[0], arguments[0], /*hlo_execution_profile=*/nullptr)); - return std::vector({result}); + return std::move(return_values); } - std::vector return_values( - run_options.size()); for (size_t i = 0; i < run_options.size(); ++i) { // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched @@ -52,9 +52,9 @@ Executable::ExecuteOnStreams( } for (const auto& options : run_options) { TF_RET_CHECK(options.stream() != nullptr); - options.stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(options.stream()->BlockHostUntilDone()); } - return return_values; + return std::move(return_values); } Status Executable::DumpSessionModule() { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 2135707371809f119f0ed427f250ea500f786d3c..23864dda78fa9e9aeefc44c5aa018686e998a558 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -44,8 +44,15 @@ namespace xla { // interface that is used for launching compiled programs across platforms. class Executable { public: - explicit Executable(std::unique_ptr hlo_module) - : hlo_module_(std::move(hlo_module)) {} + explicit Executable(std::unique_ptr hlo_module, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : hlo_module_(std::move(hlo_module)), + hlo_profile_printer_(std::move(hlo_profile_printer)), + hlo_profile_index_map_(std::move(hlo_profile_index_map)) { + CHECK_EQ(hlo_profile_printer_.get() == nullptr, + hlo_profile_index_map_.get() == nullptr); + } virtual ~Executable() {} // Enqueues the compilation result on the provided stream, passing the given @@ -54,16 +61,7 @@ class Executable { // If the hlo_execution_profile is provided as non-nullptr, profiling will be // enabled. // - // Returns the device memory region that a successful execution would - // populate. - virtual StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) = 0; - - // Overload of ExecuteOnStream which returns and takes arguments as - // ShapedBuffers. Used for LocalService execution. + // Returns a shaped buffer containing the result of the computation. virtual StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -71,21 +69,19 @@ class Executable { // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. - virtual StatusOr ExecuteAsyncOnStream( + virtual StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) = 0; + tensorflow::gtl::ArraySlice arguments) = 0; // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. - virtual StatusOr> - ExecuteOnStreams( + virtual StatusOr>> ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> + tensorflow::gtl::ArraySlice> arguments); // Populates `hlo_execution_profile` from `executor`. This is implicit in any @@ -123,12 +119,20 @@ class Executable { "Equality test on this executable is not implemented."); } + const HloProfilePrinter& hlo_profile_printer() const { + CHECK(hlo_profiling_enabled()); + return *hlo_profile_printer_; + } + + const HloProfileIndexMap& hlo_profile_index_map() const { + CHECK(hlo_profiling_enabled()); + return *hlo_profile_index_map_; + } + // Returns whether this executable was compiled with HLO profilings support // enabled. If not, the caller should not expect an hlo_execution_profile // passed to ExecuteOnStream above to be populated during execution. - bool hlo_profiling_enabled() const { - return hlo_module_->config().hlo_profiling_enabled(); - } + bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } const HloModule& module() const { return *hlo_module_; } @@ -160,10 +164,6 @@ class Executable { static Status DumpToDirectory(const string& directory_path, string filename, const SessionModule& session_module); - // Returns a cost analysis object appropriate for the platform on which this - // executable can run. - virtual std::unique_ptr CreateCostAnalysis() const = 0; - protected: mutable tensorflow::mutex mutex_; @@ -181,6 +181,9 @@ class Executable { // Execution count, used to generate a unique filename for each dumped // execution. int64 execution_count_ = 0; + + std::unique_ptr hlo_profile_printer_; + std::unique_ptr hlo_profile_index_map_; }; template @@ -200,7 +203,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? MakeUnique(module(), *CreateCostAnalysis()) + ? MakeUnique(&hlo_profile_printer(), + &hlo_profile_index_map()) : nullptr; auto return_value = @@ -208,14 +212,19 @@ StatusOr Executable::ExecuteOnStreamWrapper( if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; - stream->ThenStopTimer(timer.get()).BlockHostUntilDone(); + stream->ThenStopTimer(timer.get()); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); VLOG(1) << "done with block-host-until-done"; // Merge in run-time profile information from execution_profile. profile->MergeFrom(execution_profile()); // Overall execution time (in nanoseconds) from the executor timer. - profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); + if (stream->ok()) { + // Don't read timer->Nanoseconds() if the stream isn't OK -- that's + // illegal. + profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); + } // TODO(b/28123297): On GPU we end up including transfer time in // the compute time this way. Instead, we should get the correct diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index c225e62e3e11d2d01251b0f92272b0949eff8af1..2f0b9ed2bd98fbea4e67c0a30d5aa41ff6a06979 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -39,9 +39,7 @@ AsyncExecution::AsyncExecution(Backend* backend, tensorflow::Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { - if (!stream->BlockHostUntilDone()) { - return InternalError("failed to block until done"); - } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index dfba22a6c4c5cf071c2cd8621643b8da6587ee3b..2b6caa149439a86d6d047605099bc3ff7b295a8e 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -26,7 +26,10 @@ namespace xla { namespace { -// Helper to replace the called computation at a while- or call-instruction. +// Helper to replace the called computation at a while-, call-, or +// conditional-instruction. This function replaces exactly one instance of +// 'computation' with 'new_computation' even if 'instruction' calls +// 'computation' more than once. void ReplaceCalledComputation(HloInstruction* instruction, HloComputation* computation, HloComputation* new_computation) { @@ -45,6 +48,15 @@ void ReplaceCalledComputation(HloInstruction* instruction, instruction->set_to_apply(new_computation); break; } + case HloOpcode::kConditional: { + if (computation == instruction->true_computation()) { + instruction->set_true_computation(new_computation); + } else { + CHECK_EQ(computation, instruction->false_computation()); + instruction->set_false_computation(new_computation); + } + break; + } default: LOG(FATAL) << "unexpected opcode: " << HloOpcodeString(instruction->opcode()); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index a68e90b7d009890012f94baa790d911871c9c960..d3854b40de3572a60df1ad99d8a4589f59ad7194 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -223,5 +223,35 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { EXPECT_EQ(1, b_node.caller_callsites().size()); } +TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { + auto module = CreateNewModule(); + HloComputation* sub_computation = + module->AddEmbeddedComputation(MakeScalarComputation()); + + // Create entry computation, which is a conditional that has the same + // computation in the true and false branch. + HloComputation::Builder builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, constant1, sub_computation, constant2, + sub_computation)); + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(2, module->computation_count()); + + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + EXPECT_TRUE(result); + std::unique_ptr call_graph = CallGraph::Build(module.get()); + // The true and false computations must now be different. + EXPECT_EQ(3, module->computation_count()); + + const CallGraphNode& sub_node = call_graph->GetNode(sub_computation); + EXPECT_EQ(1, sub_node.caller_callsites().size()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 74aa77b4f165be76fbc0a8aa1a4a7e90a8e9acec..271a856efd66f9f977ac4e201161ba4b505f31e1 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -51,83 +51,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { return platform_id_; } -Status GenericTransferManager::TransferLiteralFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, Literal* literal) { - VLOG(2) << "transferring literal shape from device: " - << ShapeUtil::HumanString(literal_shape) - << "; device location: " << source.opaque(); - TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape)); - - // Tuples are a special case and contain one or more shapes inside of them to - // an arbitrary nesting depth. - if (device_shape.element_type() == TUPLE) { - *literal->mutable_shape() = literal_shape; - TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - ShallowCopyTupleFromDevice(executor, source, device_shape)); - TF_RET_CHECK(element_buffers.size() == - ShapeUtil::TupleElementCount(device_shape)); - for (int64 i = 0; i < element_buffers.size(); ++i) { - const Shape& element_device_shape = device_shape.tuple_shapes(i); - const Shape& element_literal_shape = literal_shape.tuple_shapes(i); - Literal* element_literal = literal->add_tuple_literals(); - // Recursively call TransferFromDevice to copy over the data in the - // element array. - TF_RETURN_IF_ERROR(TransferLiteralFromDevice( - executor, element_buffers[i], /*device_shape=*/element_device_shape, - /*literal_shape=*/element_literal_shape, element_literal)); - } - return Status::OK(); - } - - *literal->mutable_shape() = device_shape; - literal->Reserve(ShapeUtil::ElementsIn(device_shape)); - TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape), - /*destination=*/literal->MutableInternalData())); - if (!ShapeUtil::Equal(literal_shape, device_shape)) { - *literal = std::move(*literal->Relayout(literal_shape.layout())); - } - TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); - return Status::OK(); -} - -StatusOr> -GenericTransferManager::ShallowCopyTupleFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsTuple(shape)); - - // For devices which use the GenericTransferManager, a tuple is stored as an - // array of pointers to buffers. Copy the contents of the tuple buffer into - // a vector of void* pointers. - std::vector element_pointers(ShapeUtil::TupleElementCount(shape), - nullptr); - int64 tuple_size = ShapeUtil::ByteSizeOf(shape, pointer_size_); - auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, - element_pointers.data()); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape)); - } - - // Create a DeviceMemoryBase from each void* pointer. - std::vector destination; - for (size_t i = 0; i < element_pointers.size(); ++i) { - if (element_pointers[i] == nullptr && - !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { - return FailedPrecondition("tuple contains nullptr at element %lu", i); - } - destination.emplace_back(element_pointers[i], - GetByteSizeRequirement(shape.tuple_shapes(i))); - } - return std::move(destination); -} - -Status GenericTransferManager::WriteTuplePointersToDevice( +Status GenericTransferManager::WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, const Shape& shape, perftools::gputools::DeviceMemoryBase* region) { @@ -145,16 +69,19 @@ StatusOr> GenericTransferManager::TransferLiteralFromDevice( se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "transferring literal from device ordinal " - << executor->device_ordinal() << "; device shape: " - << ShapeUtil::HumanStringWithLayout(device_buffer.shape()) - << "; opaque: " << device_buffer.buffer(/*index=*/{}).opaque(); + << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + // The on-host and on-device shape should always be the same for the generic + // transfer manager. + TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), + device_buffer.on_host_shape())); + std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.shape()); + Literal::CreateFromShape(device_buffer.on_host_shape()); TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { if (!ShapeUtil::IsTuple(subshape)) { TF_RETURN_IF_ERROR(TransferBufferFromDevice( @@ -175,16 +102,22 @@ Status GenericTransferManager::TransferLiteralToDevice( const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " - << ShapeUtil::HumanString(shape) << "; device location: " - << device_buffer.buffer(/*index=*/{}).opaque(); + << ShapeUtil::HumanString(shape) + << "; device buffer: " << device_buffer; + + // The on-host and on-device shape should always be the same for the generic + // transfer manager. + TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), + device_buffer.on_host_shape())); - TF_RET_CHECK(ShapeUtil::Compatible(literal.shape(), device_buffer.shape())); + TF_RET_CHECK( + ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_host_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); if (ShapeUtil::IsArray(device_subshape)) { @@ -212,33 +145,6 @@ Status GenericTransferManager::TransferLiteralToDevice( }); } -Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, - se::DeviceMemoryBase* destination) { - const Shape& shape = literal.shape(); - VLOG(2) << "transferring literal shape to device: " - << ShapeUtil::HumanString(shape) - << "; device location: " << destination->opaque(); - - if (ShapeUtil::IsTuple(literal.shape())) { - std::vector tuple_elements_on_device; - for (const Literal& tuple_element : literal.tuple_literals()) { - se::DeviceMemoryBase allocation = executor->AllocateArray( - GetByteSizeRequirement(tuple_element.shape())); - TF_RETURN_IF_ERROR( - TransferLiteralToDevice(executor, tuple_element, &allocation)); - tuple_elements_on_device.push_back(allocation.opaque()); - } - return TransferBufferToDevice( - executor, tuple_elements_on_device.size() * sizeof(void*), - tuple_elements_on_device.data(), destination); - } - - return TransferBufferToDevice(executor, - /*size=*/GetByteSizeRequirement(shape), - /*source=*/literal.InternalData(), destination); -} - Status GenericTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const Literal& literal) { return Unimplemented("Generic transfer to Infeed"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 50dca6aec5012f0b02cb54846b622f008600e48e..63a7c820cf4e5fbbdf870086a4fb5316ac50d10b 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -42,16 +42,6 @@ class GenericTransferManager : public TransferManager { perftools::gputools::Platform::Id PlatformId() const override; - Status TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, - Literal* literal) override; - - Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - perftools::gputools::DeviceMemoryBase* destination) override; - StatusOr> TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; @@ -62,9 +52,6 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; - Status TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; @@ -73,16 +60,13 @@ class GenericTransferManager : public TransferManager { tensorflow::gtl::ArraySlice executors) override; - StatusOr> - ShallowCopyTupleFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& shape) override; - int64 GetByteSizeRequirement(const Shape& shape) const override; protected: - Status WriteTuplePointersToDevice( + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; + + Status WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 364b76b93c288f13f2bf447cebfc25f705d77826..f673f0cbd079b2e3a7e783c02ab9d9af2f466b63 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -343,15 +343,16 @@ tf_cc_test( ) cc_library( - name = "copy_insertion", - srcs = ["copy_insertion.cc"], - hdrs = ["copy_insertion.h"], + name = "gpu_copy_insertion", + srcs = ["gpu_copy_insertion.cc"], + hdrs = ["gpu_copy_insertion.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla/service:call_graph", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:logical_buffer", - "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", + "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", ], ) @@ -427,14 +428,14 @@ cc_library( hdrs = ["gpu_compiler.h"], deps = [ ":convolution_folding", - ":copy_insertion", ":fusion_merger", + ":gpu_copy_insertion", ":gpu_executable", + ":gpu_layout_assignment", ":hlo_schedule", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -444,10 +445,11 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:algebraic_simplifier", - "//tensorflow/compiler/xla/service:batchnorm_rewriter", + "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -491,9 +493,9 @@ cc_library( ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "gpu_layout_assignment", + srcs = ["gpu_layout_assignment.cc"], + hdrs = ["gpu_layout_assignment.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", @@ -507,10 +509,10 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", - srcs = ["layout_assignment_test.cc"], + name = "gpu_layout_assignment_test", + srcs = ["gpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":gpu_layout_assignment", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -574,11 +576,14 @@ tf_cc_test( deps = [ ":instruction_fusion", ":while_transformer", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index 5aaf072f9d2c95e2fff70a1c5337432a12a1aa48..b0626ca3bc9f843e513d4727932f0e2d5fa37748 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -55,28 +55,20 @@ MatchBackwardFilter(HloInstruction* conv) { // v v // Convolution // conv - // | - // v - // Transpose (optional if identity transposition) CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); - // If the forward convolution is followed by a transpose, we can fuse the - // transpose into the backward convolution as well. - HloInstruction* transpose = nullptr; - if (conv->user_count() == 1) { - HloInstruction* single_user = *conv->users().begin(); - if (single_user->opcode() == HloOpcode::kTranspose) { - transpose = single_user; - } - } // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = conv->convolution_dimension_numbers(); auto input_batch_dim = conv_dnums.input_batch_dimension(); auto input_feature_dim = conv_dnums.input_feature_dimension(); + auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); + auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); + auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); + auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); auto output_batch_dim = conv_dnums.output_batch_dimension(); auto output_feature_dim = conv_dnums.output_feature_dimension(); - auto spatial_dims = conv_dnums.spatial_dimensions(); + auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); for (const WindowDimension& window_dim : conv->window().dimensions()) { if (window_dim.stride() != 1) { @@ -95,9 +87,14 @@ MatchBackwardFilter(HloInstruction* conv) { VLOG(1) << "Padding low should be non-negative."; return no_match_result; } + if (window_dim.window_reversal()) { + VLOG(1) << "Window reversal field not supported"; + return no_match_result; + } // Padding high will be checked in Step 3. } - if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) { + if (input_batch_dim == output_batch_dim && + !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " "to fold it to a backward filter convolution."; @@ -108,11 +105,11 @@ MatchBackwardFilter(HloInstruction* conv) { // // Compute the window of the backward convolution. Window backward_conv_window; - for (int i = 0; i < spatial_dims.size(); ++i) { + for (int i = 0; i < input_spatial_dims.size(); ++i) { WindowDimension* dim = backward_conv_window.add_dimensions(); // The window size of the backward convolution equals the output size of the // forward convolution. - int64 filter_size = conv->shape().dimensions(spatial_dims[i]); + int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]); dim->set_size(filter_size); // The window stride equals the window dilation of the forward convolution. dim->set_stride(conv->window().dimensions(i).window_dilation()); @@ -120,7 +117,8 @@ MatchBackwardFilter(HloInstruction* conv) { // activations. dim->set_padding_low(conv->window().dimensions(i).padding_low()); - int64 input_size = conv->operand(0)->shape().dimensions(spatial_dims[i]); + int64 input_size = + conv->operand(0)->shape().dimensions(input_spatial_dims[i]); int64 output_size = conv->window().dimensions(i).size(); // Compute the range of the amount of valid high padding. We first compute // min_padding_high, the amount of padding on the right/bottom to ensure the @@ -167,50 +165,32 @@ MatchBackwardFilter(HloInstruction* conv) { } } - // To make future HLO passes easier, we canonicalize the fused expression by - // adding an identity transposition if it's omitted in the pattern. - if (transpose == nullptr) { - // Create an identity transposition with the same rank as the forward - // convolution. - HloComputation* parent_computation = conv->parent(); - std::vector transpose_dimensions(ShapeUtil::Rank(conv->shape())); - std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0); - transpose = - parent_computation->AddInstruction(HloInstruction::CreateTranspose( - conv->shape(), conv, transpose_dimensions)); - TF_CHECK_OK(conv->ReplaceAllUsesWith(transpose)); - } - // Restore the dimension numbers of the backward convolution from the forward // convolution. The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; backward_conv_dnums.set_input_batch_dimension(input_feature_dim); backward_conv_dnums.set_input_feature_dimension(input_batch_dim); - backward_conv_dnums.set_output_batch_dimension(output_feature_dim); - backward_conv_dnums.set_output_feature_dimension(output_batch_dim); - for (int i = 0; i < spatial_dims.size(); ++i) { - backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]); + for (int i = 0; i < input_spatial_dims.size(); ++i) { + backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); + } + backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); + backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); + for (int i = 0; i < kernel_spatial_dims.size(); ++i) { + backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); } // The dimension numbering of the output of the forward convolution (before // transposition) is the same as that of the activations (according to the // semantics of kConvolution). The batch dimension of the activations should // be treated as the input feature dimension, and the feature dimension should // be treated as the output feature. - // - // The output of the forward convolution needs to be transposed to fit into - // the dimension numbering of the weight gradients. This transposition maps - // dimension i to PositionInContainer(transpose->dimensions(), i). - backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), output_batch_dim)); - backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), output_feature_dim)); - for (int i = 0; i < spatial_dims.size(); ++i) { - backward_conv_dnums.add_kernel_spatial_dimensions( - PositionInContainer(transpose->dimensions(), spatial_dims[i])); + backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); + backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); + for (int i = 0; i < output_spatial_dims.size(); ++i) { + backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector({transpose, conv}), + return std::make_tuple(true, std::vector({conv}), backward_conv_window, backward_conv_dnums); } @@ -270,14 +250,20 @@ MatchBackwardInput(HloInstruction* conv) { << " should have no window dilation."; return no_match_result; } + if (window_dim.window_reversal()) { + VLOG(1) << "Window reversal field not supported"; + return no_match_result; + } } - const auto& spatial_dims = dnums.spatial_dimensions(); - CHECK_EQ(conv->window().dimensions().size(), spatial_dims.size()); + const auto& input_spatial_dims = dnums.input_spatial_dimensions(); + const auto& output_spatial_dims = dnums.output_spatial_dimensions(); + CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size()); + CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size()); const Window& old_window = conv->window(); Window new_window = old_window; - for (size_t i = 0; i < spatial_dims.size(); ++i) { + for (size_t i = 0; i < input_spatial_dims.size(); ++i) { // Restore backward convolution's padding config from the matched pattern. // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc // for how we convert backward input convolution to a variant of forward @@ -310,8 +296,9 @@ MatchBackwardInput(HloInstruction* conv) { // end at the border. The maximum amount (max_padding_high) equals // min_padding_high+stride-1 -- max_padding_high+1 would cause the output // size to change. - auto unpadded_input_size = conv->shape().dimensions(spatial_dims[i]); - auto output_size = conv->operand(0)->shape().dimensions(spatial_dims[i]); + auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]); + auto output_size = + conv->operand(0)->shape().dimensions(input_spatial_dims[i]); auto padded_input_size = kernel_size + dim->stride() * (output_size - 1); auto total_pad_size = padded_input_size - unpadded_input_size; auto min_padding_high = total_pad_size - backward_padding_low; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc index 19b122ba0603b4ec08d73e05da4c2ae11a760553..34e6bdb117d47a3d7e1eb3bae5806e130e94ea79 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -46,23 +46,27 @@ class ConvolutionFoldingTest : public HloTestBase { // // TODO(jingyue): Add more tests on NCHW input order which TF also supports. tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); - tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); - tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( 3); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); - tf_default_dnums_for_backward_input_.add_spatial_dimensions(1); - tf_default_dnums_for_backward_input_.add_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2); tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2); tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0); @@ -82,7 +86,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) { +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -132,7 +136,7 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + EXPECT_TRUE(FoldConvolution(module.get())); } // Extracted from block35 training. @@ -151,13 +155,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { conv_window.mutable_dimensions(i)->set_padding_low(1); conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -185,13 +185,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { conv_window.mutable_dimensions(i)->set_padding_high(-1); conv_window.mutable_dimensions(i)->set_window_dilation(2); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -218,13 +214,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { // Uneven padding: padding_low=0, padding_high=1 conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -258,8 +250,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { conv_dnums.set_output_batch_dimension(0); conv_dnums.set_input_feature_dimension(1); conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_spatial_dimensions(2); - conv_dnums.add_spatial_dimensions(3); + conv_dnums.add_input_spatial_dimensions(2); + conv_dnums.add_output_spatial_dimensions(2); + conv_dnums.add_input_spatial_dimensions(3); + conv_dnums.add_output_spatial_dimensions(3); conv_dnums.set_kernel_input_feature_dimension(0); conv_dnums.set_kernel_output_feature_dimension(1); conv_dnums.add_kernel_spatial_dimensions(2); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index e79d0a4c795c16a5c3298f69b3e3dcea55a97b9c..899cc5c83b99f1bb6154f883ca17871863e1f457 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -29,12 +29,12 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { +using se::dnn::AlgorithmDesc; using se::dnn::BatchDescriptor; using se::dnn::ConvolutionDescriptor; using se::dnn::DataLayout; using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; -using se::dnn::AlgorithmDesc; ConvolveScratchAllocator::ConvolveScratchAllocator( int device_ordinal, DeviceMemoryAllocator* memory_allocator) @@ -131,8 +131,9 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( const int effective_num_dimensions = std::max(2, num_dimensions); CHECK_EQ(F32, output_shape_.element_type()); - CHECK_EQ(num_dimensions, dim_nums_.spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size()); for (const WindowDimension& dim : window_.dimensions()) { CHECK_EQ(dim.padding_low(), dim.padding_high()); } @@ -148,7 +149,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( // Note that the dimensions are reversed. The same holds below. input_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - input_shape_.dimensions(dim_nums_.spatial_dimensions(dim))); + input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim))); } FilterDescriptor filter_descriptor(effective_num_dimensions); @@ -182,7 +183,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( for (int dim = 0; dim < num_dimensions; ++dim) { output_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - output_shape_.dimensions(dim_nums_.spatial_dimensions(dim))); + output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim))); } // Add a singleton dimension in the 1D convolution case. @@ -258,22 +259,19 @@ tensorflow::Status ConvolutionThunk::Convolve( } std::vector ConvolutionThunk::GetAlgorithms( - se::StreamExecutor* stream_exec) const { + bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const { std::vector algorithms; - // TODO(yangzihao): Currently disable the use of winograd nonfused in XLA - // by default. Should send in conv parameters and enable it when - // ShouldIncludeWinogradNonfusedAlgo() returns true. switch (convolution_kind_) { case ConvolutionKind::kBackwardFilter: CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - /*with_winograd_nonfused=*/false, &algorithms)); + with_winograd_nonfused, &algorithms)); break; case ConvolutionKind::kBackwardInput: CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - /*with_winograd_nonfused=*/false, &algorithms)); + with_winograd_nonfused, &algorithms)); break; case ConvolutionKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false, + CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, &algorithms)); break; } @@ -287,6 +285,26 @@ static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) { return tensorflow::strings::StrCat(algo.algo_id()); } +// Determines whether we can safely perform a winograd non-fused convolution for +// the given input and output descriptors. This works around b/68264959, an +// integer overflow in cuDNNv5 and cuDNNv6. +static bool ShouldIncludeWinogradNonfusedAlgo( + const BatchDescriptor& input_descriptor, + const BatchDescriptor& output_descriptor) { + int64 batch = input_descriptor.count(); + int64 in_depths = input_descriptor.feature_map_count(); + int64 in_rows = input_descriptor.height(); + int64 in_cols = input_descriptor.width(); + int64 out_depths = output_descriptor.feature_map_count(); + + int64 total_size = 16 * std::ceil(batch / 16.0) * + std::max(in_depths, out_depths) * in_cols * in_rows * + sizeof(float); + int64 threshold = 1L << 31; + + return total_size < threshold; +} + tensorflow::Status ConvolutionThunk::ConvolveWithTune( const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, const FilterDescriptor& filter_descriptor, @@ -296,16 +314,22 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( const ConvolutionDescriptor& convolution_descriptor, const BufferAllocations& buffer_allocations, se::Stream* stream) { // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - if (best_algorithm_.algorithm().is_default()) { + if (!best_algorithm_.has_value()) { + best_algorithm_.emplace(); + // Auto-tuning either is disabled or only happens in the first run of this // function. VLOG(2) << "Profiling for best convolution algorithm used for " "ConvolutionThunk: " << this; + bool with_winograd_nonfused = + ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor); + se::dnn::ProfileResult best_result; se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = GetAlgorithms(stream->parent()); + std::vector algorithms = + GetAlgorithms(with_winograd_nonfused, stream->parent()); for (auto algorithm : algorithms) { ConvolveScratchAllocator scratch_allocator( buffer_allocations.device_ordinal(), @@ -341,35 +365,35 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( } if (best_result.is_valid()) { - best_algorithm_.set_algorithm(best_result.algorithm()); + best_algorithm_->set_algorithm(best_result.algorithm()); } else { LOG(ERROR) << "No convolution algorithm works with profiling. Fall back " "to the default algorithm."; - best_algorithm_.set_algorithm(AlgorithmDesc()); + best_algorithm_->set_algorithm(AlgorithmDesc()); } if (best_result_without_scratch.is_valid()) { - best_algorithm_.set_algorithm_no_scratch( + best_algorithm_->set_algorithm_no_scratch( best_result_without_scratch.algorithm()); } else { LOG(ERROR) << "No convolution algorithm without scratch works with " "profiling. Fall back " "to the default algorithm."; - best_algorithm_.set_algorithm_no_scratch(AlgorithmDesc()); + best_algorithm_->set_algorithm_no_scratch(AlgorithmDesc()); } } { VLOG(2) << "Using convolution algorithm (" - << AlgorithmToString(best_algorithm_.algorithm()) << ", " - << AlgorithmToString(best_algorithm_.algorithm_no_scratch()) + << AlgorithmToString(best_algorithm_->algorithm()) << ", " + << AlgorithmToString(best_algorithm_->algorithm_no_scratch()) << ") for ConvolutionThunk: " << this; ConvolveScratchAllocator scratch_allocator( buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); return Convolve(input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, best_algorithm_, stream, + convolution_descriptor, *best_algorithm_, stream, &scratch_allocator, nullptr); } } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 13432301b2af34ab4bd0864e39ce22366cc1d11d..7c25a2e6450e30292667ecd7de54b50ac2450767 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -87,6 +88,14 @@ class ConvolutionThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if the next run of ExecuteOnStream will do autotuning. If so, + // we want the GPU to be quiescent during autotuning, so as not to introduce + // noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream*) override { + return !best_algorithm_.has_value(); + } + private: tensorflow::Status ConvolveWithTune( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, @@ -116,13 +125,15 @@ class ConvolutionThunk : public Thunk { // Returns the convolve algorithms that can be used for this ConvolutionThunk. std::vector GetAlgorithms( + bool with_winograd_nonfused, perftools::gputools::StreamExecutor* stream_exec) const; // Fastest cuDNN convolution algorithm for this thunk learned from // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set - // to the default value indicating cuDNN's convolution will choose - // the best algorithm from some heuristics based on its parameters. - perftools::gputools::dnn::AlgorithmConfig best_algorithm_; + // to the default value, indicating cuDNN's convolution will choose the best + // algorithm from some heuristics based on its parameters. + tensorflow::gtl::optional + best_algorithm_; const ConvolutionKind convolution_kind_; diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc deleted file mode 100644 index 3dc85552015be67c20db9099704334c864b44b51..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/copy_insertion.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace gpu { - -StatusOr GpuCopyInsertion::Run(HloModule* module) { - TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module)); - - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Make sure all operands of a library call are in memory instead of constants - // in IR. The top-level (index {}) of the points-to set of each operand - // indicates the source(s) of the array buffer. If any of these are constant, - // then add a copy to materialize the array. - HloComputation* computation = module->entry_computation(); - for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { - if (ImplementedAsLibraryCall(*hlo)) { - for (int64 i = 0; i < hlo->operand_count(); ++i) { - HloInstruction* operand = hlo->mutable_operand(i); - const PointsToSet& points_to = - points_to_analysis->GetPointsToSet(operand); - const auto& element = points_to.element(/*index=*/{}); - if (std::any_of(element.begin(), element.end(), - [](const LogicalBuffer* buffer_source) { - return buffer_source->instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - CopyInsertion::FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); - changed = true; - } - } - } - } - - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 6bf00cfb8a53723ae9608093480bf2eed10144dd..4b511cb4bb94addfae53d6b2e6d6f86d5b9afd84 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -135,10 +135,6 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kAtan2: - return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value}, - {lhs_input_type, rhs_input_type}, - output_type); case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, {lhs_input_type, rhs_input_type}, @@ -199,29 +195,50 @@ StatusOr GpuElementalIrEmitter::EmitErfcInv( return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitLog( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitSin( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitCos( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitExp( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, + prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { + return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, + prim_type); +} + StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kExp: - return EmitLibdeviceMathCall("__nv_exp", {operand_value}, {input_type}, - output_type); case HloOpcode::kFloor: return EmitLibdeviceMathCall("__nv_floor", {operand_value}, {input_type}, output_type); case HloOpcode::kCeil: return EmitLibdeviceMathCall("__nv_ceil", {operand_value}, {input_type}, output_type); - case HloOpcode::kLog: - return EmitLibdeviceMathCall("__nv_log", {operand_value}, {input_type}, - output_type); - case HloOpcode::kCos: - return EmitLibdeviceMathCall("__nv_cos", {operand_value}, {input_type}, - output_type); - case HloOpcode::kSin: - return EmitLibdeviceMathCall("__nv_sin", {operand_value}, {input_type}, - output_type); case HloOpcode::kTanh: return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); @@ -230,224 +247,6 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( } } -StatusOr GpuElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - TF_RET_CHECK(primitive_util::IsComplexType(input_type)); - PrimitiveType component_type = - primitive_util::ComplexComponentType(input_type); - switch (op->opcode()) { - case HloOpcode::kPower: { - // (a+bi)^(c+di) = - // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), - // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) - auto a = EmitExtractReal(lhs_value); - auto b = EmitExtractImag(lhs_value); - auto c = EmitExtractReal(rhs_value); - auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); - auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = ir_builder_->CreateFMul(one_half, c); - - TF_ASSIGN_OR_RETURN( - auto aa_p_bb_to_half_c, - EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c}, - {component_type, component_type}, - component_type)); - auto neg_d = ir_builder_->CreateFNeg(d); - TF_ASSIGN_OR_RETURN( - auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a}, - {component_type, component_type}, - component_type)); - auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); - TF_ASSIGN_OR_RETURN( - auto e_to_neg_d_arg_lhs, - EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type}, - component_type)); - auto coeff = - ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); - TF_ASSIGN_OR_RETURN( - auto ln_aa_p_bb, - EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type}, - component_type)); - auto half_d = ir_builder_->CreateFMul(one_half, d); - auto q = - ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), - ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); - TF_ASSIGN_OR_RETURN( - auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type}, - component_type)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), - ir_builder_->CreateFMul(coeff, sin_q)); - } - default: - return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value); - } -} - -StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - PrimitiveType component_type = - primitive_util::IsComplexType(input_type) - ? primitive_util::ComplexComponentType(input_type) - : input_type; - - switch (op->opcode()) { - case HloOpcode::kLog: { - // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - llvm::Type* llvm_ty = a->getType(); - auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); - TF_ASSIGN_OR_RETURN( - auto log_sum_sq, - EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a}, - {component_type, component_type}, - component_type)); - auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex( - op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); - } - case HloOpcode::kExp: { - // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( - auto exp_a, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, - component_type)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); - } - case HloOpcode::kCos: { - // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = EmitExtractReal(operand_value); - auto llvm_ty = a->getType(); - TF_ASSIGN_OR_RETURN( - auto exp_b, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, - component_type)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); - } - - case HloOpcode::kSin: { - // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = EmitExtractReal(operand_value); - auto llvm_ty = a->getType(); - TF_ASSIGN_OR_RETURN( - auto exp_b, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, - component_type)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); - } - case HloOpcode::kTanh: { - /* - tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) - e^(a+bi) = e^a*(cos(b)+sin(b)i) - so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / - (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) - cos(b)=cos(-b), sin(-b)=-sin(b) - so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / - (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) - =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / - (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) - =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / - (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) - This is a complex division, so we can multiply by denom_conj/denom_conj - =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * - (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / - ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) - =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + - i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / - ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) - */ - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( - auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, - component_type)); - auto exp_neg_a = ir_builder_->CreateFDiv( - llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(exp_a, exp_a), - ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); - auto real_num = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); - auto exp_a_plus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); - auto exp_a_minus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = ir_builder_->CreateFMul( - cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, - exp_a_minus_exp_neg_a_sq)); - auto denom = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), - ir_builder_->CreateFDiv(imag_num, denom)); - } - default: - return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); - } -} - llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( const string& callee_name, tensorflow::gtl::ArraySlice operands, diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 6a537d015209bc507af36b13eeb5d69ce58d8fea..77d4569b1e8e398005e8f517ff086a77aedd382d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -54,20 +54,31 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; - StatusOr EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; - StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitLog(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitSin(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitCos(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitExp(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; + + StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; + llvm::Value* EmitThreadId() const override; private: diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index e784046450ed1cca088770c65c786e80adda869f..8e3aebbc12b5e6d746700956b9743bc94db50167 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -264,9 +264,9 @@ tensorflow::Status GemmThunk::ExecuteOnStream( auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape, bool transpose) -> MatrixDescriptor { - bool is_row_major = shape.layout().minor_to_major(0) != 0; - bool layout_mismatch = shape.layout().minor_to_major(0) != - output_shape_.layout().minor_to_major(0); + bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0; + bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) != + LayoutUtil::Minor(output_shape_.layout(), 0); return MatrixDescriptor(data, transpose ^ layout_mismatch, shape.dimensions(is_row_major), shape.dimensions(!is_row_major)); @@ -320,7 +320,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( }; bool launch_ok; - if (output_shape_.layout().minor_to_major(0) == 0) { + if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) { launch_ok = launch( lhs_descriptor, rhs_descriptor, MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 983cb872924f22be0dfad8aa9ad86f233b909c46..8c6a1f51a8a09ef78950dfe7e89994a3fe247f49 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -52,6 +52,15 @@ class GemmThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if we'll perform autotuning if run on the given stream. If + // so, we want the GPU to be quiescent during autotuning, so as not to + // introduce noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* stream) override { + return autotune_results_.count( + stream->parent()->GetDeviceDescription().name()) != 0; + } + private: const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 23fb308ec6b4ec363cfba318fa4e1236766069ae..fc3b299936779dc938a6777e7da7907a3b43a3be 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -27,21 +27,22 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" -#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" @@ -126,7 +127,7 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { // Runs optimization passes on the given HLO module. tensorflow::Status OptimizeHloModule( - HloModule* hlo_module, const se::DeviceDescription& device_desc, + HloModule* hlo_module, const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { { HloPassPipeline pipeline("optimization"); @@ -137,15 +138,15 @@ tensorflow::Status OptimizeHloModule( // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); - + pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification"); pass.AddInvariantChecker(shape_size_function); - // TODO(b/62764704): Do not rewrite on GPU, use cuDNN's BatchNorm APIs + // TODO(b/62764704): Do not expand on GPU, use cuDNN's BatchNorm APIs // instead. - pass.AddPass( + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, @@ -224,9 +225,8 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } @@ -295,21 +295,26 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, } // namespace GpuCompiler::GpuCompiler() - : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} + : pointer_size_(llvm::DataLayout(kDataLayout) + .getPointerSize(0 /* default address space */)) {} + +StatusOr> GpuCompiler::RunHloPasses( + std::unique_ptr module, se::StreamExecutor* /*stream_exec*/) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); + Tracing::TraceMe annotation("HLO Transforms", module->name(), + /*is_expensive=*/true); + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction())); + return std::move(module); +} -StatusOr> GpuCompiler::Compile( +StatusOr> GpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); + TF_RET_CHECK(stream_exec != nullptr); - { - Tracing::TraceMe annotation("HLO Transforms", module->name(), - /*is_expensive=*/true); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), - stream_exec->GetDeviceDescription(), - ShapeSizeBytesFunction())); - TF_RETURN_IF_ERROR( - PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); - } + TF_RETURN_IF_ERROR( + PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); llvm::LLVMContext llvm_context; std::string buffer; @@ -362,8 +367,11 @@ StatusOr> GpuCompiler::Compile( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, &ir_emitter_context); - TF_RETURN_IF_ERROR( - entry_computation->root_instruction()->Accept(&ir_emitter)); + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); + TF_RETURN_IF_ERROR( + entry_computation->root_instruction()->Accept(&ir_emitter)); + } if (user_pre_optimization_hook_) { TF_CHECK_OK(user_pre_optimization_hook_(llvm_module)); @@ -412,9 +420,12 @@ StatusOr> GpuCompiler::Compile( cc_minor = 0; } - TF_ASSIGN_OR_RETURN(string ptx, - CompileToPtx(&llvm_module, {cc_major, cc_minor}, - module->config(), libdevice_dir)); + string ptx; + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - CompileToPtx"); + TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, + module->config(), libdevice_dir)); + } if (!ir_dump_directory.empty()) { TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( @@ -456,10 +467,20 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "Printing the thunk schedule..."; XLA_VLOG_LINES(2, thunk_schedule->ToString()); - auto* gpu_executable = - new GpuExecutable(ptx, cubin, {cc_major, cc_minor}, - std::move(thunk_schedule), std::move(module), - std::move(buffer_assignment), ShapeSizeBytesFunction()); + std::unique_ptr profile_index_map; + std::unique_ptr profile_printer; + + if (module->config().hlo_profiling_enabled()) { + HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + profile_index_map = MakeUnique(*module); + profile_printer = + CreateHloProfilePrinter(*profile_index_map, cost_analysis); + } + + auto* gpu_executable = new GpuExecutable( + ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule), + std::move(module), std::move(buffer_assignment), + std::move(profile_printer), std::move(profile_index_map)); if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); @@ -470,6 +491,7 @@ StatusOr> GpuCompiler::Compile( std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, int cc_major, int cc_minor) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult"); Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true); bool inserted; decltype(compilation_cache_.begin()) iter; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index fe5fce615fc1fbf12b14d626398b56dc7ece81e8..18e34340205b6f51497e26c45520799d21c55a46 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -49,7 +49,11 @@ class GpuCompiler : public LLVMCompiler { // stream_execs) using LLVMCompiler::Compile; - StatusOr> Compile( + StatusOr> RunHloPasses( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr> RunBackend( std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc new file mode 100644 index 0000000000000000000000000000000000000000..33d739b79d3664fec3586bbc924b7fa2e10d3256 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace gpu { + +StatusOr GpuCopyInsertion::FindOrInsertCopy( + HloInstruction* hlo) { + HloInstruction*& copy = inserted_copies_[hlo]; + if (copy == nullptr) { + TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); + } + return copy; +} + +StatusOr GpuCopyInsertion::Run(HloModule* module) { + CopyInsertion generic_copy_insertion; + + TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, + HloDataflowAnalysis::Run(module)); + + // Make sure all operands of a library call are in memory instead of constants + // in IR. + for (HloInstruction* hlo : + module->entry_computation()->MakeInstructionPostOrder()) { + if (ImplementedAsLibraryCall(*hlo)) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + HloInstruction* operand = hlo->mutable_operand(i); + TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + const auto& values = dataflow->GetValueSet(operand).values(); + if (std::any_of(values.begin(), values.end(), + [](const HloValue* value) { + return value->defining_instruction()->opcode() == + HloOpcode::kConstant; + })) { + TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); + changed = true; + } + } + } + } + + // Init values of a while node cannot be constants. Insert copies for any + // constants found at the operand of a while. + tensorflow::gtl::FlatSet copied_constants; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + for (auto& pair : + dataflow->GetInstructionValueSet(instruction->operand(0))) { + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (value->defining_instruction()->opcode() == + HloOpcode::kConstant && + !ContainsKey(copied_constants, value->defining_instruction())) { + HloInstruction* constant = value->defining_instruction(); + TF_ASSIGN_OR_RETURN(HloInstruction * copy, + FindOrInsertCopy(constant)); + TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); + copied_constants.insert(constant); + changed = true; + } + } + } + } + } + + // The GPU backend needs additional copies added due to deficiencies in + // buffer assignment. + TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed, + CopyInsertion::AddCopiesForBufferAssignment(module)); + + return changed || buffer_assignment_changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h similarity index 56% rename from tensorflow/compiler/xla/service/gpu/copy_insertion.h rename to tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 11077dad2e5506eab4fa84d47ad13a26ed1c035a..4d77f337e6eb20f7d79acc0829fde26bbe443f25 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { @@ -25,12 +25,23 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. -class GpuCopyInsertion : public CopyInsertion { +class GpuCopyInsertion : public HloPassInterface { public: + tensorflow::StringPiece name() const override { return "copy-insertion"; } + StatusOr Run(HloModule* module) override; + + protected: + // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making + // duplicate copies. + StatusOr FindOrInsertCopy(HloInstruction* hlo); + + // A map containing all copies inserted to materialize operands of library + // calls. The key is the copied instruction and the value is the copy. + tensorflow::gtl::FlatMap inserted_copies_; }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index c6f23f9b0506186c4f76a887e6a540dafdd79962..366d87e9c30ed043b38c8e0cea889d5d90e7c8d9 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -69,7 +69,7 @@ class HloExecutionProfiler { ~HloExecutionProfiler() { if (do_profile_) { stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone(); + stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); } @@ -87,7 +87,7 @@ class HloExecutionProfiler { void FinishOperation(const HloInstruction* hlo_instruction) { if (do_profile_) { stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone(); + stream_->BlockHostUntilDone().IgnoreError(); profile_->SetCyclesTakenBy( hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); } @@ -113,14 +113,15 @@ GpuExecutable::GpuExecutable( std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - HloCostAnalysis::ShapeSizeFunction shape_size_function) - : Executable(std::move(hlo_module)), + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), ptx_(ptx), cubin_(cubin), compute_capability_(compute_capability), thunk_schedule_(std::move(thunk_schedule)), - assignment_(std::move(assignment)), - shape_size_function_(std::move(shape_size_function)) {} + assignment_(std::move(assignment)) {} Status GpuExecutable::ExecuteThunks( const ServiceExecutableRunOptions* run_options, @@ -166,9 +167,16 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } + // If this thunk requests it, wait for all currently-executing thunks to + // finish. This is useful e.g. if the thunk is about to perform autotuning. + if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { + TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); + } + profiler.StartOperation(); VLOG(2) << "Executing the thunk for " - << thunk->hlo_instruction()->ToString(); + << thunk->hlo_instruction()->ToString() << " on stream " + << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); @@ -183,90 +191,16 @@ Status GpuExecutable::ExecuteThunks( // Make sure kernels are completed before deallocating temporary buffers. // TODO(b/30100571): we could potentially postpone deallocating the temp // buffers until a different computation is executed. - if (block_host_until_done && !main_stream->BlockHostUntilDone()) { - return InternalError("Failed to complete all kernels launched on stream %p", - main_stream); - } - - return Status::OK(); -} - -StatusOr GpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - - BufferAllocations::Builder buffer_allocations_builder; - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - const BufferAllocation& allocation = assignment_->GetAllocation(i); - if (allocation.is_entry_computation_parameter()) { - buffer_allocations_builder.RegisterBuffer( - i, arguments[allocation.parameter_number()]); + if (block_host_until_done) { + Status block_status = main_stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError( + "Failed to complete all kernels launched on stream %p: %s", + main_stream, block_status.error_message().c_str()); } } - se::StreamExecutor* executor = stream->parent(); - TF_ASSIGN_OR_RETURN( - auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); - bool block_host_until_done = - !memory_allocator->AllowsAsynchronousDeallocation(); - TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations, - block_host_until_done, - hlo_execution_profile)); - - HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice output_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase output_buffer_address = - buffer_allocations->GetDeviceAddress(output_slice.index()); - - if (ShapeUtil::IsTuple(root->shape())) { - std::set referred_by_output; - if (GetRootPointsToSet().IsAmbiguous()) { - // The points-to set of the root is ambiguous so we need to examine the - // result data to determine which buffers are contained in the result. - TF_ASSIGN_OR_RETURN( - TransferManager * transfer_manager, - TransferManager::GetForPlatform(executor->platform())); - TF_ASSIGN_OR_RETURN(referred_by_output, - transfer_manager->GatherBufferPointersFromTuple( - executor, output_buffer_address, root->shape())); - } else { - // The points-to set of the root is unambiguous so it's known statically - // which buffers are in the result. Gather these buffers using the root's - // points-to set. - TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElementWithStatus( - [&referred_by_output, &buffer_allocations, this]( - const ShapeIndex& /*index*/, - const PointsToSet::BufferList& buffers) { - // The points to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction produced - // the array at this element. - CHECK_EQ(1, buffers.size()); - HloInstruction* hlo = buffers[0]->instruction(); - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice(hlo, buffers[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - referred_by_output.insert( - buffer_allocations->GetDeviceAddress(slice.index())); - return Status::OK(); - })); - } - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(referred_by_output, *assignment_)); - } else { - // If the computation result is not a tuple, we can delete all temporary - // buffers that are not the output. - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown({output_buffer_address}, *assignment_)); - } - return output_buffer_address; + return Status::OK(); } StatusOr> GpuExecutable::ExecuteOnStream( @@ -286,7 +220,7 @@ StatusOr> GpuExecutable::ExecuteOnStream( if (allocation.is_entry_computation_parameter()) { auto param_no = allocation.parameter_number(); buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->buffer(/*index=*/{})); + i, arguments[param_no]->root_buffer()); } } se::StreamExecutor* executor = run_options->stream()->parent(); @@ -304,50 +238,46 @@ StatusOr> GpuExecutable::ExecuteOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); auto device_ordinal = executor->device_ordinal(); auto shaped_buffer = MakeUnique( - root->shape(), executor->platform(), device_ordinal); + root->shape(), root->shape(), executor->platform(), device_ordinal); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer. std::set buffers_in_result; - TF_RETURN_IF_ERROR( - shaped_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = this->GetRootPointsToSet().element(index); - // The points-to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction - // produced the array at this element. - CHECK_EQ(1, sources.size()); - auto src_hlo = sources[0]->instruction(); - - VLOG(4) << "Looking at: " << sources[0]; - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src_hlo, sources[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - perftools::gputools::DeviceMemoryBase src_base = - buffer_allocations->GetDeviceAddress(slice.index()); - CHECK(!src_base.is_null() || src_base.size() == 0); - shaped_buffer->mutable_buffers()->push_back(src_base); - *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1; - - buffers_in_result.insert(src_base); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + // The points-to set is unambiguous so the set should be a + // singleton. That is, we know exactly which instruction + // produced the array at this element. + CHECK_EQ(1, sources.size()); + auto src_hlo = sources[0]->instruction(); + + VLOG(4) << "Looking at: " << sources[0]; + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + perftools::gputools::DeviceMemoryBase src_base = + buffer_allocations->GetDeviceAddress(slice.index()); + CHECK(!src_base.is_null() || src_base.size() == 0); + *device_memory = src_base; + buffers_in_result.insert(src_base); + return Status::OK(); + })); TF_RETURN_IF_ERROR( buffer_allocations->TearDown(buffers_in_result, *assignment_)); return std::move(shaped_buffer); } -StatusOr GpuExecutable::ExecuteAsyncOnStream( +StatusOr> GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on GPU."); @@ -358,9 +288,5 @@ const PointsToSet& GpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr GpuExecutable::CreateCostAnalysis() const { - return MakeUnique(shape_size_function_); -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index a3815370c19af1da612bc6d9663cc0f8896062f7..00da64dfade8ddb0694c0ee7ac158c9f2e15a508 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -54,7 +54,8 @@ class GpuExecutable : public Executable { std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - HloCostAnalysis::ShapeSizeFunction shape_size_function); + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -71,32 +72,22 @@ class GpuExecutable : public Executable { // empty, in which case compilation is left up to the GPU driver. const std::vector& cubin() const { return cubin_; } - // Both overloads of ExecuteOnStream will fail if the compute capability of - // the stream doesn't match the compute capability passed to this object's - // constructor. - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - + // ExecuteOnStream will fail if the compute capability of the stream doesn't + // match the compute capability passed to this object's constructor. StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; const Status EqualOrFail(const Executable& executable) { // TODO(b/62952745) Implement equality test on GPU executable. return Unimplemented("Equality test on GPU executable is not implemented."); } - std::unique_ptr CreateCostAnalysis() const override; - private: // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for @@ -140,9 +131,6 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. const std::unique_ptr assignment_; - // Function to compute the size of a given Shape, in bytes. - const HloCostAnalysis::ShapeSizeFunction shape_size_function_; - TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 0bbd63fb7bfc657cb7bb1de673253c198f5bd25f..50a249f448e7b4956e7bf6bd603d256eca88f71d 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include @@ -80,9 +80,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( const ConvolutionDimensionNumbers& dimension_numbers = instruction->convolution_dimension_numbers(); std::vector input_layout; - for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.spatial_dimensions(i)); + for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; + i >= 0; --i) { + input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); } input_layout.push_back(dimension_numbers.input_feature_dimension()); input_layout.push_back(dimension_numbers.input_batch_dimension()); @@ -102,9 +102,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); std::vector output_layout; - for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0; - --i) { - output_layout.push_back(dimension_numbers.spatial_dimensions(i)); + for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; + i >= 0; --i) { + output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); } output_layout.push_back(dimension_numbers.output_feature_dimension()); output_layout.push_back(dimension_numbers.output_batch_dimension()); diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h similarity index 86% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.h rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 169041eb85c633cb4f1f679bcea127714828308f..7655a3ebf45f83c0125a4257baae7a7229ebdc6d 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -38,4 +38,4 @@ class GpuLayoutAssignment : public LayoutAssignment { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc similarity index 97% rename from tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index ac206b89d329d7e4ac91ee51162c9694f6899d78..f68b23c8ce969372a01ce77840e016d82ca5d2ed 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f0f036f7f381db15b84db85d3efeec5d8141884e..ae92daef8882de2e7d64b69f68452061cb5507f2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -44,7 +44,7 @@ GpuTransferManager::GpuTransferManager() : GenericTransferManager( se::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize()) {} + .getPointerSize(0 /* default address space */)) {} Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { @@ -105,12 +105,13 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( // infeed requests, blocking on the stream might be // heavy-handed. Figure out if finer-grained acknowledgement is // possible. - if (!stream->BlockHostUntilDone()) { + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } - return InternalError("Failed to complete data transfer on stream %p", - stream); + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); } infeed_manager->EnqueueBuffers(buffers); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index e33e904692ca5ad41e17d2e165dbb40b6bd4aa33..2ac95ceb692447c7ac6dbbcd8b9a38876f7a77b6 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -30,9 +30,8 @@ InfeedThunk::InfeedThunk( tuple_element_buffers.end()), destination_buffer_(destination_buffer) {} -tensorflow::Status InfeedThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { +Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { VLOG(2) << "Infeeding to GPU "; perftools::gputools::DeviceMemoryBase destination_address = @@ -66,15 +65,16 @@ tensorflow::Status InfeedThunk::ExecuteOnStream( buffer->length()); } - if (!stream->BlockHostUntilDone()) { - return InternalError("Failed to complete data transfer on stream %p", - stream); + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); } infeed_manager->ReleaseBuffers(infeed_buffers); VLOG(2) << "Infeeding to GPU complete"; - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 371d71f9dbdd21cb5f36cc3108c8f398a4a91c29..86918705fa0305217f11753e383200c7bd71474b 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -43,9 +43,8 @@ class InfeedThunk : public Thunk { InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 9a4bfd0905bb62c02c70e7f2eea46872c07bca89..1d47ffde4331868cbc8a8afb2d01b11e77a7fab0 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -156,8 +156,10 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { conv_dnums.set_output_batch_dimension(0); conv_dnums.set_input_feature_dimension(1); conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_spatial_dimensions(2); - conv_dnums.add_spatial_dimensions(3); + conv_dnums.add_input_spatial_dimensions(2); + conv_dnums.add_output_spatial_dimensions(2); + conv_dnums.add_input_spatial_dimensions(3); + conv_dnums.add_output_spatial_dimensions(3); conv_dnums.set_kernel_output_feature_dimension(0); conv_dnums.set_kernel_input_feature_dimension(1); conv_dnums.add_kernel_spatial_dimensions(2); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 8fb7a6adda9dc7c36eb9aabcbcdc9d77e6c22c4a..c04a7e0bf8fb5a4f4f73892bdef1b0b3e9879778 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -100,7 +100,7 @@ bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kConvolution) { const ConvolutionDimensionNumbers& dnums = hlo.convolution_dimension_numbers(); - if (dnums.spatial_dimensions_size() > 3) { + if (dnums.input_spatial_dimensions_size() > 3) { return false; } @@ -110,6 +110,10 @@ bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { return false; } + if (window_util::HasWindowReversal(hlo.window())) { + return false; + } + return true; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6e2bd4e11d3c4ff576edb0df3b724abebfc0e424..e71aa0d13306c9d6571c5c26b0b6f430655df09f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -173,7 +173,7 @@ Status IrEmitter::EmitCallToNestedComputation( return Status::OK(); } -bool IrEmitter::MaybeEmitSpecialAtomicOperation( +bool IrEmitter::MaybeEmitDirectAtomicOperation( const HloComputation& computation, llvm::Value* output_address, llvm::Value* source_address) { CHECK_EQ(2, computation.num_parameters()); @@ -233,102 +233,189 @@ bool IrEmitter::MaybeEmitSpecialAtomicOperation( return false; } -Status IrEmitter::EmitAtomicOperationForNestedComputation( - const HloComputation& computation, llvm::Value* output_address, - llvm::Value* source_address) { - if (computation.num_parameters() != 2) { - // TODO(b/30258929): We only accept binary computations so far. - return Unimplemented( - "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); - } +// Implements atomic binary operations using atomic compare-and-swap +// (atomicCAS) as follows: +// 1. Reads the value from the memory pointed to by output_address and +// records it as old_output. +// 2. Uses old_output as one of the source operand to perform the binary +// operation and stores the result in new_output. +// 3. Calls atomicCAS which implements compare-and-swap as an atomic +// operation. In particular, atomicCAS reads the value from the memory +// pointed to by output_address, and compares the value with old_output. If +// the two values equal, new_output is written to the same memory location +// and true is returned to indicate that the atomic operation succeeds. +// Otherwise, the new value read from the memory is returned. In this case, +// the new value is copied to old_output, and steps 2. and 3. are repeated +// until atomicCAS succeeds. +// +// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If +// the element type of the binary operation is 32 bits or 64 bits, the integer +// type of the same size is used for the atomicCAS operation. On the other hand, +// if the element type is smaller than 32 bits, int32 is used for the atomicCAS +// operation. In this case, atomicCAS reads and writes 32 bit values from +// the memory, which is larger than the memory size required by the original +// atomic binary operation. We mask off the last two bits of the output_address +// and use the result as an address to read the 32 bit values from the memory. +// This can avoid out of bound memory accesses if tensor buffers are 4 byte +// aligned and have a size of 4N, an assumption that the runtime can guarantee. +// +// The pseudo code is shown below. Variables *_address are pointers to a memory +// region with a size equal to the size of the atomicCAS operation, with the +// exception that new_output_address is a pointer to a memory region with a size +// equal to the element size of the binary operation. +// +// element_size = sizeof(element_type); +// atomic_size = max(32, element_size); +// cas_new_output_address = alloca(atomic_size); +// cas_old_output_address = alloca(atomic_size); +// if (atomic_size != element_size) { +// atomic_address = output_address & ((int64)(-2)); +// new_output_address = cas_new_output_address + (output_address & 3); +// } else { +// atomic_address = output_address; +// new_output_address = cas_new_output_address; +// } +// +// *cas_old_output_address = *atomic_address; +// do { +// *cas_new_output_address = *cas_old_output_address; +// *new_output_address = operation(*new_output_address, *source_address); +// (*cas_old_output_address, success) = +// atomicCAS(atomic_address, *cas_old_output_address, +// *cas_new_output_address); +// } while (!success); +// +Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address) { + llvm::PointerType* output_address_type = + llvm::dyn_cast(output_address->getType()); + CHECK_NE(output_address_type, nullptr); + + // element_type is the data type for the binary operation. + llvm::Type* element_type = output_address_type->getPointerElementType(); + int element_size = llvm_ir::GetSizeInBits(element_type); + llvm::Type* element_address_type = element_type->getPointerTo(); + + int atomic_size = (element_size < 32) ? 32 : element_size; + llvm::Type* atomic_type = ir_builder_.getIntNTy(atomic_size); + llvm::Type* atomic_address_type = + atomic_type->getPointerTo(output_address_type->getPointerAddressSpace()); + + // cas_old_output_address and cas_new_output_address point to the scratch + // memory where we store the old and new values for the repeated atomicCAS + // operations. + llvm::Value* cas_old_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); - if (MaybeEmitSpecialAtomicOperation(computation, output_address, - source_address)) { - return Status::OK(); - } - - // Other binary computations can be made atomic as following (labels are basic - // block names used in the IR emitting code later). - // - // atomic_op_loop_preheader: - // ... - // source = *source_address; - // old_output = *output_address; - // do { - // atomic_op_loop_body_entry: - // new_output = computation(old_output, source); - // (old_output, success) = - // atomicCAS(output_address, old_output, new_output); - // } while (!success); - // - // atomic_op_loop_exit: - // ... - // - // TODO(jingyue): Consider encapsulate the logic of emitting control flow to - // something similar to llvm_ir::ForLoop. - // // Emit preparation code to the preheader. llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock(); - llvm::Type* element_ir_type = - output_address->getType()->getPointerElementType(); - // old_output = *output_address; - llvm::Value* old_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "old_output_location"); - ir_builder_.CreateStore(ir_builder_.CreateLoad(output_address, "old_output"), - old_output_location); + + llvm::Value* atomic_memory_address; + // binop_output_address points to the scratch memory that stores the + // result of the binary operation. + llvm::Value* binop_output_address; + if (element_size < 32) { + // Assume the element size is an integer number of bytes. + CHECK_EQ((element_size % sizeof(char)), 0); + llvm::Type* address_int_type = + module_->getDataLayout().getIntPtrType(output_address_type); + atomic_memory_address = + ir_builder_.CreatePtrToInt(output_address, address_int_type); + llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); + llvm::Value* offset = ir_builder_.CreateAnd(atomic_memory_address, mask); + mask = llvm::ConstantInt::get(address_int_type, -2); + atomic_memory_address = ir_builder_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = + ir_builder_.CreateIntToPtr(atomic_memory_address, atomic_address_type); + binop_output_address = ir_builder_.CreateAdd( + ir_builder_.CreatePtrToInt(cas_new_output_address, address_int_type), + offset); + binop_output_address = + ir_builder_.CreateIntToPtr(binop_output_address, element_address_type); + } else { + atomic_memory_address = + ir_builder_.CreateBitCast(output_address, atomic_address_type); + binop_output_address = + ir_builder_.CreateBitCast(cas_new_output_address, element_address_type); + } + + // Use the value from the memory that atomicCAS operates on to initialize + // cas_old_output. + llvm::Value* cas_old_output = + ir_builder_.CreateLoad(atomic_memory_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_old_output_address); + llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( ir_builder_.GetInsertPoint(), "atomic_op_loop_exit"); - - // Emit the body of the loop that repeatedly invokes atomicCAS. llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body", ir_builder_.GetInsertBlock()->getParent()); ir_builder_.SetInsertPoint(loop_body_bb); // Change preheader's successor from loop_exit_bb to loop_body_bb. loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb); - // new_output = computation(old_output, source); - llvm::Value* new_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "new_output_location"); + + // Emit the body of the loop that repeatedly invokes atomicCAS. + // + // Use cas_old_output to initialize cas_new_output. + cas_old_output = + ir_builder_.CreateLoad(cas_old_output_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_new_output_address); + // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - computation, {old_output_location, source_address}, new_output_location)); - - // (old_output, success) = atomicCAS(output_address, old_output, new_output); - int num_bits = llvm_ir::GetSizeInBits(element_ir_type); - llvm::Type* element_int_ir_type = ir_builder_.getIntNTy(num_bits); - // cmpxchg accepts integer only, and bitcast refuses to operate on aggregate - // types, so we bitcast load and store addresses to intN* of the same bit - // width. - llvm::Value* old_output = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(old_output_location, - element_int_ir_type->getPointerTo()), - "old_output"); - llvm::Value* new_output = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(new_output_location, - element_int_ir_type->getPointerTo()), - "new_output"); + computation, {binop_output_address, source_address}, + binop_output_address)); + + llvm::Value* cas_new_output = + ir_builder_.CreateLoad(cas_new_output_address, "cas_new_output"); + + // Emit code to perform the atomicCAS operation + // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, + // cas_new_output); llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( - ir_builder_.CreateBitCast(output_address, - element_int_ir_type->getPointerTo()), - old_output, new_output, llvm::AtomicOrdering::SequentiallyConsistent, + atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); - // cmpxchg returns a pair. The first element is the original value at - // output_address and the second element is whether the swap is successful. + + // Extract the memory value returned from atomicCAS and store it as + // cas_old_output. ir_builder_.CreateStore( - ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), - ir_builder_.CreateBitCast(old_output_location, - element_int_ir_type->getPointerTo())); + ir_builder_.CreateExtractValue(ret_value, 0, "cas_old_output"), + cas_old_output_address); + // Extract the success bit returned from atomicCAS and generate a + // conditional branch on the success bit. ir_builder_.CreateCondBr( ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); - // Restore the insertion point to the exit basic block so that the caller of + // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. SetToFirstInsertPoint(loop_exit_bb, &ir_builder_); return Status::OK(); } +Status IrEmitter::EmitAtomicOperationForNestedComputation( + const HloComputation& computation, llvm::Value* output_address, + llvm::Value* source_address) { + if (computation.num_parameters() != 2) { + // TODO(b/30258929): We only accept binary computations so far. + return Unimplemented( + "We only support atomic functions with exactly two parameters, but " + "computation %s has %lld.", + computation.name().c_str(), computation.num_parameters()); + } + + if (MaybeEmitDirectAtomicOperation(computation, output_address, + source_address)) { + return Status::OK(); + } + + return EmitAtomicOperationUsingCAS(computation, output_address, + source_address); +} + Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); auto on_true = select->operand(1); @@ -640,6 +727,37 @@ Status IrEmitter::HandleRng(HloInstruction* random) { .EmitLoop(IrName(random)); } +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + + llvm::Value* conditional_result = GetBasePointer(*conditional); + + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + GetBasePointer(*pred), + llvm_ir::AsStringRef(IrName(conditional, "load_predicate_value"))); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + llvm_ir::AsStringRef(IrName(conditional, "boolean_predicate"))); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + pred_cond, IrName(conditional, "if_then_else"), &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *conditional->true_computation(), {GetBasePointer(*true_arg)}, + conditional_result)); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *conditional->false_computation(), {GetBasePointer(*false_arg)}, + conditional_result)); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { @@ -648,8 +766,8 @@ llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( // reduction dimension. std::vector dimensions; const Shape& shape = operand_array.GetShape(); - for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape.layout().minor_to_major(i); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape.layout(), i); if (dimension != reduction_dimension) { dimensions.push_back(dimension); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 9c01f5b7c72f429822300af28bfd5261150d33d1..08bbbe36c72872ba68104c8f328c2f602eb30fa8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -95,6 +95,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleRng(HloInstruction* random) override; + Status HandleConditional(HloInstruction* conditional) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -185,9 +186,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { // be simply implemented using an LLVM atomic instruction. If "computation" is // one of this kind, emits code to do that and returns true; otherwise, // returns false. - bool MaybeEmitSpecialAtomicOperation(const HloComputation& computation, - llvm::Value* output_address, - llvm::Value* source_address); + bool MaybeEmitDirectAtomicOperation(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); + + // A helper method for EmitAtomicOperationForNestedComputation. It implements + // binary atomic operations using atomicCAS with special handling to support + // small data types. + Status EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); StatusOr ComputeNestedElement( const HloComputation& computation, @@ -227,6 +235,7 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleDot(HloInstruction* dot) override; Status HandleFusion(HloInstruction* fusion) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1b863c9e3c51d6e757751154abd653cd1fdcb8a7..022c63de8db00dba8a626e76751113a3f9356537 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -123,10 +123,12 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), launch_dims.threads_per_block()); + // Our launch bounds are exact, so we can specify them as reqntidx rather than + // maxntidx. nvvm_annotations_node->addOperand(llvm::MDNode::get( llvm_context, {llvm::ConstantAsMetadata::get(ir_kernel), - llvm::MDString::get(llvm_context, "maxntidx"), + llvm::MDString::get(llvm_context, "reqntidx"), llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } } // namespace @@ -246,6 +248,11 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); @@ -254,6 +261,11 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { return IrEmitter::HandleDot(dot); } +Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { + thunk_sequence_->push_back(BuildKernelThunk(conditional)); + return IrEmitter::HandleConditional(conditional); +} + Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { if (ImplementedAsDnnConvolution(*convolution)) { thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); @@ -421,10 +433,10 @@ std::tuple IsTranspose021(const Shape& a, const Shape& b) { CHECK(ShapeUtil::Compatible(a, b)); std::vector perm(a.dimensions().size()); { - std::vector layout_a(a.layout().minor_to_major().rbegin(), - a.layout().minor_to_major().rend()); - std::vector layout_b(b.layout().minor_to_major().rbegin(), - b.layout().minor_to_major().rend()); + auto layout_a_orig = LayoutUtil::MinorToMajor(a); + std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); + auto layout_b_orig = LayoutUtil::MinorToMajor(b); + std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); for (size_t i = 0; i < perm.size(); ++i) { perm[i] = PositionInContainer(layout_b, layout_a[i]); } @@ -800,9 +812,9 @@ Status IrEmitterUnnested::EmitColumnReduction( // normalized_input_shape to input_matrix_shape. const Shape normalized_input_shape = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); const std::vector transpose_dimension_mapping( - input_shape.layout().minor_to_major().rbegin(), - input_shape.layout().minor_to_major().rend()); + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); const Shape input_matrix_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( @@ -1043,9 +1055,9 @@ Status IrEmitterUnnested::EmitRowReduction( // normalized_input_shape to input_3d_tensor_shape. const Shape normalized_input_shape = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); const std::vector transpose_dimension_mapping( - input_shape.layout().minor_to_major().rbegin(), - input_shape.layout().minor_to_major().rend()); + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); const Shape input_3d_tensor_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( input_shape.element_type(), {depth, height, width}); @@ -1177,9 +1189,9 @@ Status IrEmitterUnnested::EmitReductionToVector( // whether another dimension is major or minor of them. std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(), [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(input_shape.layout().minor_to_major(), + return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); }); // Now, if output rank is at least 1, `input_dims_to_keep.front()` is @@ -1224,14 +1236,14 @@ Status IrEmitterUnnested::EmitReductionToVector( int64 width = 1; for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); ++input_dim) { - if (PositionInContainer(input_shape.layout().minor_to_major(), + if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dim) > - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dims_to_keep.back())) { depth *= input_shape.dimensions(input_dim); - } else if (PositionInContainer(input_shape.layout().minor_to_major(), + } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dim) < - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dims_to_keep.front())) { width *= input_shape.dimensions(input_dim); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 1cb963be611de23cfb9fbb6eca639019208b3d7a..059943d48cd34b0ac487b91c3f3079ee3f761229 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -34,7 +34,7 @@ limitations under the License. #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/CodeGen/CommandFlags.def" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -77,7 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = + const tensorflow::Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); @@ -492,9 +492,8 @@ StatusOr CompileToPtx(llvm::Module* module, tensorflow::port::Tracing::TraceMe annotation( "Compiling IR", llvm_ir::AsString(module->getName()), /*is_expensive=*/true); - ScopedLoggingTimer compilation_timer( - "Compile module " + llvm_ir::AsString(module->getName()), - /*vlog_level=*/2); + XLA_SCOPED_LOGGING_TIMER("Compile module " + + llvm_ir::AsString(module->getName())); TF_ASSIGN_OR_RETURN( ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config, libdevice_dir_path)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 9274e16a455fc1a958cee5101b6a9ef7ce619347..c29fee0879c02021fdc23ac0e02ab398cf40f99e 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -49,8 +49,8 @@ HloInstruction* MaybePaddedAndSlicedInput( // applies positive padding and dilation. PaddingConfig padding_config = MakeNoPaddingConfig(input->shape().dimensions_size()); - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.input_spatial_dimensions(i); padding_config.mutable_dimensions(dim)->set_edge_padding_low( std::max(0LL, conv_window.dimensions(i).padding_low())); padding_config.mutable_dimensions(dim)->set_edge_padding_high( @@ -81,8 +81,8 @@ HloInstruction* MaybePaddedAndSlicedInput( std::vector limit_indices(input->shape().dimensions().begin(), input->shape().dimensions().end()); std::vector strides(input->shape().dimensions_size(), 1); - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.input_spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or // decrement the limit index by the amount of negative padding. start_indices[dim] += @@ -117,8 +117,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) { padding_config.add_dimensions(); } - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.kernel_spatial_dimensions(i); padding_config.mutable_dimensions(dim)->set_interior_padding( conv_window.dimensions(i).window_dilation() - 1); } @@ -202,8 +202,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* transpose = backward_conv->fused_expression_root(); - HloInstruction* forward_conv = transpose->mutable_operand(0); + HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); @@ -229,7 +228,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // later. Therefore, the amount of new padding (low or high) is the minimum // of the amount of old padding low and old padding high. int64 new_conv_padding = std::min(padding_low, padding_high); - int64 dim = backward_conv_dnums.spatial_dimensions(i); + int64 dim = backward_conv_dnums.input_spatial_dimensions(i); input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( padding_low - new_conv_padding); input_padding_config.mutable_dimensions(dim)->set_edge_padding_high( @@ -269,19 +268,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), padded_input, output, new_forward_conv_window, forward_conv_dnums)); - HloInstruction* new_transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - ShapeInference::InferTransposeShape(new_forward_conv->shape(), - transpose->dimensions()) - .ConsumeValueOrDie(), - new_forward_conv, transpose->dimensions())); - - // Fuse the new forward convolution and the new transpose to the new backward - // convolution. + // Fuse the new forward convolution to the new backward convolution. HloInstruction* new_backward_conv = computation->CreateFusionInstructionForBackwardConvolution( - {new_transpose, new_forward_conv}, - HloInstruction::FusionKind::kConvBackwardFilter, + {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, new_backward_conv_window, backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; @@ -369,12 +359,11 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( std::vector limit_indices( new_backward_conv->shape().dimensions().begin(), new_backward_conv->shape().dimensions().end()); - std::vector strides(new_backward_conv->shape().dimensions_size(), - 1LL); + std::vector strides(new_backward_conv->shape().dimensions_size(), 1LL); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); - int64 dim = backward_conv_dnums.spatial_dimensions(i); + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); if (padding_low > padding_high) { // If the amount of low padding (of the old backward convolution) is // larger, we internally pad the low end of the activations and slice diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d0d2deee24848184278e3e51dcaa3bb673b5fadc..6cf280df05496716a0780d61ded92efd9982734c 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -44,37 +44,41 @@ std::ostream& operator<<(std::ostream& out, // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, const se::DeviceDescription& device_desc, - PartitionStrategy partition_strategy) { - int64 warp_size = device_desc.threads_per_warp(); - + const Shape& shape, const se::DeviceDescription& device_desc) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); } - // Calculate the number of threads per block. - // Initialize threads_per_block as the threads-per-block limit. - int64 threads_per_block = device_desc.threads_per_block_limit(); - VLOG(2) << "Initial # of threads per block = " << threads_per_block; - - if (partition_strategy == PartitionStrategy::kLatency) { - // Limit the thread count to allow maximum number of registers per thread. - // TODO(b/28560520): We don't have to assume the emitted kernel will use up - // all the registers. We could use ptxas to examine the actual number of - // register used, and set the thread count accordingly. - int64 threads_per_block_limit_due_to_registers = - device_desc.registers_per_core_limit() / - device_desc.registers_per_thread_limit(); - CHECK_NE(0, threads_per_block_limit_due_to_registers); - if (threads_per_block_limit_due_to_registers < threads_per_block) { - threads_per_block = - // Make `threads_per_block` a multiple of warp size to use GPU - // efficiently. - warp_size * - std::max(1LL, threads_per_block_limit_due_to_registers / warp_size); - VLOG(2) << "Update # of threads per block due to register pressure = " - << threads_per_block; + // Since we don't do any inter-warp communication, we're free to choose any + // block size we want, subject to hardware constraints. We choose the + // smallest block size that allows the GPU to reach full occupancy (assuming + // the kernel uses sufficiently few registers). This gives us max performance + // when the kernel uses few registers, and lets us scale down gracefully as + // the kernel uses more registers. + // + // Specifically, we choose the number of threads per block such that + // + // * = + + auto threads_per_core = device_desc.threads_per_core_limit(); + auto blocks_per_core = device_desc.blocks_per_core_limit(); + int64 threads_per_block; + if (threads_per_core != 0 && blocks_per_core != 0) { + threads_per_block = device_desc.threads_per_core_limit() / + device_desc.blocks_per_core_limit(); + } else { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; } } @@ -84,8 +88,6 @@ LaunchDimensions CalculateLaunchDimensions( << threads_per_block << ") because the latter is smaller."; } - // Calculate the block count. We copy the strategy used by Eigen: - // eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h int64 block_count = CeilOfRatio(num_elements, threads_per_block); VLOG(2) << tensorflow::strings::Printf( "Initialized the block count to ceil(# of elements / threads per " diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 8f7fce884acc93fd39510ad0826b819a6d9731a7..0bf463a6ef95d5a32784838c08ad239752fd1acf 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -30,14 +30,6 @@ limitations under the License. namespace xla { namespace gpu { -enum class PartitionStrategy { - // Optimized for latency by allowing maximum number of registers per thread. - kLatency, - // Optimized for throughput. This may limit registers per thread and cause - // longer latency. - kThroughput -}; - // Encapsulates the launch dimensions of a kernel, e.g., the block count and the // number of threads per block. class LaunchDimensions { @@ -66,8 +58,7 @@ std::ostream& operator<<(std::ostream& out, LaunchDimensions CalculateLaunchDimensions( const Shape& shape, - const perftools::gputools::DeviceDescription& device_desc, - PartitionStrategy partition_strategy = PartitionStrategy::kLatency); + const perftools::gputools::DeviceDescription& device_desc); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 0ff27888ad72f8190400c22a9086d1965448662c..486ea7d7e1dad3f7f37d50565e176fbf567f5cc4 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -70,6 +70,19 @@ class Thunk { return tensorflow::Status::OK(); } + // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) + // before calling ExecuteOnStream(stream). If it returns true, it's the + // user's responsibility to wait for all activity on the GPU to finish before + // calling ExecuteOnStream. + // + // This value is not required to be constant for a given Thunk. For example, + // a Thunk that performs autotuning may return true for its first run and + // false thereafter. + virtual bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* /*stream*/) { + return false; + } + // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 0d2412096abf7838b7b0e7617811c789f507a4a1..c21559af6d2e5dfb5aaf62afcdcaed514e0914c9 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,16 +34,14 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status WhileThunk::Initialize(const GpuExecutable& executable) { +Status WhileThunk::Initialize(const GpuExecutable& executable) { TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status WhileThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - +Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { perftools::gputools::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); @@ -55,9 +53,11 @@ tensorflow::Status WhileThunk::ExecuteOnStream( // Copy the result of condition computation and break the loop if 'false'. bool condition_result; stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); - if (!stream->BlockHostUntilDone()) { + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { return InternalError( - "Failed to complete all kernels launched on stream %p", stream); + "Failed to complete all kernels launched on stream %p: %s", stream, + block_status.error_message().c_str()); } if (!condition_result) { @@ -68,7 +68,7 @@ tensorflow::Status WhileThunk::ExecuteOnStream( TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 95ed5497cea4fa3ba5dcdc6762cbd53cec88339a..4c9f45de9e42494df58706d0a4a3eb0c4220b8b8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,10 +45,9 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status Initialize(const GpuExecutable& executable) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index ccdd1717593e4fa7c1d1deb3f0f9ebfab1bf7209..ab94d7d5436e8edd12f68f7e0c395c53f303e6eb 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -44,7 +44,7 @@ namespace { // // Parameter // | -// Const GetTupleElemet +// Const GetTupleElement // \ / // Add (root) // diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 44188473d39088923c67216facab472a4e4ee09f..f16daa0b5481474e754c880ead1945297ca50168 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -17,9 +17,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -33,8 +36,6 @@ class WhileTransformerTest : public HloTestBase { : module_(CreateNewModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), - loop_state_shape_(ShapeUtil::MakeTupleShape( - {induction_variable_shape_, data_shape_})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} std::unique_ptr BuildConditionComputation( @@ -42,8 +43,8 @@ class WhileTransformerTest : public HloTestBase { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(limit))); - auto loop_state = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( + 0, GetLoopStateShape(tuple_index), "loop_state")); auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, tuple_index)); @@ -58,8 +59,8 @@ class WhileTransformerTest : public HloTestBase { const int64 increment) { auto builder = HloComputation::Builder(TestName() + ".Body"); // Create param instruction to access loop state. - auto loop_state = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( + 0, GetLoopStateShape(ind_var_tuple_index), "loop_state")); // Update the induction variable GTE(ind_var_tuple_index). auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( @@ -73,7 +74,7 @@ class WhileTransformerTest : public HloTestBase { data_shape_, loop_state, data_tuple_index)); // Use 'induction_variable' in computation with no path to output tuple. auto update = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); + HloInstruction::CreateBroadcast(data_shape_, induction_variable, {})); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. @@ -98,8 +99,9 @@ class WhileTransformerTest : public HloTestBase { HloInstruction::CreateTuple({induction_var_init, data_init})) : builder.AddInstruction( HloInstruction::CreateTuple({data_init, induction_var_init})); - auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( - loop_state_shape_, condition, body, loop_state_init)); + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index), + condition, body, loop_state_init)); module_->AddEntryComputation(builder.Build()); return while_hlo; } @@ -115,18 +117,34 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { + HloVerifier verifier([](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + }); + TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; - EXPECT_IS_OK(copy_insertion.Run(module_.get()).status()); + TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); + } + + Shape GetLoopStateShape(const int64 ind_var_tuple_index) { + if (ind_var_tuple_index == 0) { + return ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_}); + } else { + return ShapeUtil::MakeTupleShape( + {data_shape_, induction_variable_shape_}); + } } std::unique_ptr module_; Shape induction_variable_shape_; Shape data_shape_; - Shape loop_state_shape_; Shape condition_result_shape_; }; -TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -137,13 +155,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_TRUE(result.ok()); + TF_ASSERT_OK(result.status()); // Check results. EXPECT_THAT(result.ConsumeValueOrDie(), Eq(std::tuple(0, 10, 1))); } -TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -154,13 +175,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_TRUE(result.ok()); + TF_ASSERT_OK(result.status()); // Check results. EXPECT_THAT(result.ConsumeValueOrDie(), Eq(std::tuple(0, 10, 1))); } -TEST_F(WhileTransformerTest, InvalidLoopLimit) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Build computation with invalid loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); @@ -176,7 +200,10 @@ TEST_F(WhileTransformerTest, InvalidLoopLimit) { HasSubstr("Loop start must be less than loop limit.")); } -TEST_F(WhileTransformerTest, InvalidLoopIncrement) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 049e8d80d80c835bca4a4d38592564ba82a3ecf9..05017008e2ddbe0b9e78d06275fdec5d08d94bfa 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -108,8 +108,11 @@ std::unique_ptr MakeBigGraph() { HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0)); + HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 17b926c8748e45b55f380e7595711b9e7a748f64..387b649a731ebcbfd8307807469f39f22d192b06 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -259,8 +259,11 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. @@ -292,8 +295,11 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -327,10 +333,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); // The buffer for dot1 is the output. No buffers can be shared. The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -365,10 +374,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 79493c4112804f8454d200f3f83aa85d718f0d0a..e4aed7593c51a2d1bb156493666b3c1b03dcc626 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -36,6 +36,9 @@ option cc_enable_arenas = true; // Serialization of HloInstruction. message HloInstructionProto { + reserved 10; + reserved "parameter_name"; + string name = 1; string opcode = 2; xla.Shape shape = 3; @@ -50,9 +53,8 @@ message HloInstructionProto { // Literal, only present for kConstant. xla.LiteralProto literal = 8; - // Parameter info, only present for kParameter. + // Parameter number is only present for kParameter. int64 parameter_number = 9; - string parameter_name = 10; // Fusion state, only present for kFusion. string fusion_kind = 11; @@ -118,6 +120,9 @@ message HloInstructionProto { // Shape of outfeed request. xla.Shape outfeed_shape = 29; + + // Describes the dimension numbers used for a dot operation + xla.DotDimensionNumbers dot_dimension_numbers = 30; } // Serialization of HloComputation. @@ -250,7 +255,3 @@ message HloProto { HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } - -message HloProtos { - repeated HloProto hlo_protos = 1; -} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 6f8099475146e6bbcfb61d2e5a91a7a6f9e63e58..6d2a3aa5b531650a658502531e050702ffbd3760 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -144,8 +144,10 @@ class BufferValueMap { // Move the given value into the given buffer. void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - buffers_.at(old_buffer_number).erase(&value); - if (buffers_.at(old_buffer_number).empty()) { + tensorflow::gtl::FlatSet& old_value_set = + buffers_.at(old_buffer_number); + old_value_set.erase(&value); + if (old_value_set.empty()) { buffers_.erase(old_buffer_number); } @@ -175,7 +177,7 @@ class BufferValueMap { // Value is init of a while (use is while). std::vector aliased_buffers; for (const HloUse& use : value.uses()) { - VLOG(1) << "use of value " << value.ToShortString() << ": " << use; + VLOG(2) << "use of value " << value.ToShortString() << ": " << use; if (use.instruction->opcode() == HloOpcode::kWhile) { // Determine the while value that this shares a buffer with. const HloValue& while_value = @@ -411,7 +413,7 @@ string HloAliasAnalysis::ToString() const { /* static */ StatusOr> HloAliasAnalysis::Run( HloModule* module) { - VLOG(1) << "HloAliasAnalysis::Run on module " << module->name(); + VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); @@ -444,7 +446,7 @@ StatusOr> HloAliasAnalysis::Run( TF_DCHECK_OK(alias_analysis->Verify()); - XLA_VLOG_LINES(1, alias_analysis->ToString()); + XLA_VLOG_LINES(2, alias_analysis->ToString()); return std::move(alias_analysis); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 8056bcf0f791bee949c02d6ecae4af633da84179..a63affa06caf75f1ccab084bd114e39ba7c91a38 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -131,9 +131,9 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = param_instruction->parameter_name(); + string param_name = param_instruction->name(); // Fusion parameters are named foo.param_1, bar.param_2, etc. We are - // renumbering the parameters so replace the final number in the name with + // renumbering the parameters, so replace the final number in the name with // the updated value. const string param_underscore = ".param_"; size_t index = param_name.rfind(param_underscore); @@ -176,10 +176,6 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { return false; } - if (instruction->HasSideEffect()) { - return false; - } - return true; } @@ -207,7 +203,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.count(item) != 0 || item->user_count() != 0 || - item == root_instruction() || !IsRemovable(item)) { + item == root_instruction() || !IsRemovable(item) || + item->HasSideEffect()) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -367,26 +364,27 @@ std::list HloComputation::MakeEmbeddedComputationsList() return post_order; } -string HloComputation::ToString(int nested_level, - bool include_large_constants) const { +string HloComputation::ToString(const HloPrintOptions& options) const { std::ostringstream s; - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } - s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " {\n"; + if (options.print_percent()) { + s << "%"; + } + s << name(); + if (options.print_program_shape()) { + s << " " << ShapeUtil::HumanString(ComputeProgramShape()); + } + s << " {\n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } s << " " << (instruction == root_instruction_ ? "ROOT " : "") - << instruction->ToString( - /*compact_operands=*/false, - /*include_metadata=*/true, - /*include_large_constants=*/include_large_constants) - << "\n"; + << instruction->ToString(options) << "\n"; } - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } s << "}"; @@ -407,16 +405,18 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - tensorflow::gtl::FlatMap* computation_map, + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation, HloInstruction* fusion_instruction) { std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + TF_ASSIGN_OR_RETURN(std::unique_ptr instruction, + HloInstruction::CreateFromProto( + module, instruction_proto, instruction_map, + computation_map, add_fused_computation)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } @@ -541,7 +541,7 @@ ProgramShape HloComputation::ComputeProgramShape() const { for (auto* param_instruction : param_instructions_) { *program_shape.add_parameters() = param_instruction->shape(); - *program_shape.add_parameter_names() = param_instruction->parameter_name(); + *program_shape.add_parameter_names() = param_instruction->name(); } *program_shape.mutable_result() = root_instruction_->shape(); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 2835dbbb846b24599840a9ee3ea72809d3f97dd2..6436815f910405477ec21a33dec75ef71df08602 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -138,8 +138,11 @@ class HloComputation { void UniquifyName(NameUniquer* name_uniquer); // Return a string representation of the computation. - string ToString(int nested_level = 0, - bool include_large_constants = false) const; + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Returns a serialized representation of this computation. HloComputationProto ToProto() const; @@ -152,12 +155,16 @@ class HloComputation { // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // fusion_instruction: if non-null then the newly created computation will be - // constructed as a fused computation with this instruction as its fusion - // parent. + // add_fused_computation: A function to call to add a fused + // computation. Used only when the instruction is a fusion instruction. + // fusion_instruction: if non-null then the newly created computation will + // be constructed as a fused computation with this instruction as its + // fusion parent. static StatusOr> CreateFromProto( HloModule* module, const HloComputationProto& proto, - tensorflow::gtl::FlatMap* computation_map, + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation, HloInstruction* fusion_instruction = nullptr); // Gets the instructions in this computation. @@ -309,11 +316,17 @@ class HloComputation { replacements, HloModule* module = nullptr, const string& suffix = "clone"); - // Returns true if the given instruction can be removed from the - // computation. Instructions such as parameters and send/receive instructions - // cannot be removed without violating invariants of the HLO computation or - // module with the exception of fusion computation. A parameter instruction - // is removable for a fusion computation. + // Returns true if the given instruction can be removed from the computation. + // Parameter instructions cannot be removed without violating invariants of + // the HLO computation with the exception of fusion computation. A parameter + // instruction is removable for a fusion computation. + // + // Note that IsRemovable() is a necessariy condition to remove an instruction + // rather than a sufficient condition. For example, instructions with + // side-effect (e.g., Send, Infeed) may be removed from a computation, but the + // transformation must guarantee the invariants relevant to the instructions + // still hold (e.g., Send and Recv must be removed together to make each + // channel complete). bool IsRemovable(const HloInstruction* instruction); // Returns true if this computation has a side effect. A computation has a diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 1877065f672bdf705f044568e2d77ac342a808cc..b933695b823871c6c0174da6d6f99e618219442a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -22,13 +22,14 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" namespace xla { constexpr char HloCostAnalysis::kFlopsKey[]; constexpr char HloCostAnalysis::kTranscendentalsKey[]; constexpr char HloCostAnalysis::kBytesAccessedKey[]; -constexpr char HloCostAnalysis::kSecondsKey[]; +constexpr char HloCostAnalysis::kOptimalSecondsKey[]; HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size) : HloCostAnalysis(shape_size, {}) {} @@ -60,16 +61,16 @@ Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) { if (current_should_compute_bottleneck_time_) { // Compute the time as the time of the bottleneck, i.e. the slowest property // given the per-second rate of each property. - float max_seconds = 0.0f; + float optimal_seconds = 0.0f; for (const auto& property : current_properties_) { - if (property.first != kSecondsKey) { - max_seconds = std::max( - max_seconds, + if (property.first != kOptimalSecondsKey) { + optimal_seconds = std::max( + optimal_seconds, property.second / GetProperty(property.first, per_second_rates_, INFINITY)); } } - current_properties_[kSecondsKey] = max_seconds; + current_properties_[kOptimalSecondsKey] = optimal_seconds; } TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second); @@ -200,10 +201,11 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); - + int64 reduction_width = + lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); // First divide by reduction width before multiplying by rhs elements to avoid // overflow. int64 fma_count; @@ -396,7 +398,14 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. - current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape()); + double flops = 0.0; + ShapeUtil::ForEachSubshape( + crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; return Status::OK(); } @@ -480,6 +489,25 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { return Status::OK(); } +Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { + // Compute the cost of the true and false computations and take the maximum + // from those for each property. + TF_ASSIGN_OR_RETURN(const Properties true_computation_properties, + ProcessSubcomputation(conditional->true_computation())); + TF_ASSIGN_OR_RETURN(const Properties false_computation_properties, + ProcessSubcomputation(conditional->false_computation())); + current_properties_ = true_computation_properties; + for (const auto& property : false_computation_properties) { + if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, property)) { + current_properties_[property.first] = + std::max(current_properties_[property.first], property.second); + } + } + current_should_compute_bottleneck_time_ = false; + + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -496,8 +524,8 @@ float HloCostAnalysis::bytes_accessed() const { return GetProperty(kBytesAccessedKey, properties_sum_); } -float HloCostAnalysis::seconds() const { - return GetProperty(kSecondsKey, properties_sum_); +float HloCostAnalysis::optimal_seconds() const { + return GetProperty(kOptimalSecondsKey, properties_sum_); } int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const { @@ -512,8 +540,8 @@ int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_); } -float HloCostAnalysis::seconds(const HloInstruction& hlo) const { - return GetPropertyForHlo(hlo, kSecondsKey, hlo_properties_); +float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { + return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } StatusOr HloCostAnalysis::ProcessSubcomputation( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 0f447753788d870e91204fcb03eb2de204c958bf..fade19522cf0c30eab037aa355de1f9203f80014 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -42,7 +42,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { static constexpr char kFlopsKey[] = "flops"; static constexpr char kTranscendentalsKey[] = "transcendentals"; static constexpr char kBytesAccessedKey[] = "bytes accessed"; - static constexpr char kSecondsKey[] = "seconds"; + static constexpr char kOptimalSecondsKey[] = "optimal_seconds"; // shape_size is a function which returns the size in bytes of the top-level // buffer of a shape. @@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleReshape(const HloInstruction* reshape) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; + Status HandleConditional(const HloInstruction* conditional) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -118,14 +119,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor { float flop_count() const; float transcendental_count() const; float bytes_accessed() const; - float seconds() const; + float optimal_seconds() const; // Returns the respective cost computed for a particular HLO instruction, or 0 // if the HLO was not found to have a cost in the analysis. int64 flop_count(const HloInstruction& hlo) const; int64 transcendental_count(const HloInstruction& hlo) const; int64 bytes_accessed(const HloInstruction& hlo) const; - float seconds(const HloInstruction& hlo) const; + float optimal_seconds(const HloInstruction& hlo) const; const Properties& properties() const { return properties_sum_; } const float property(const string& key) const { diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 0eaa21ef254e3461baaaca57503ab24ce35ac929..3b289c240a45e8f3df8156ed89e879da2132d01a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -389,7 +389,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { static_assert(bytes_accessed == 64, ""); EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed); - EXPECT_EQ(fusion_analysis.seconds(), 1 << i); + EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i); } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 3f34b9ceb34abc89fca5b896bb8fbe3a06cd6ed4..2a335843f507e2071807245d4dd256e1ec6f08c8 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -333,6 +333,21 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { return false; } +bool HloDataflowAnalysis::UpdateConditionalValueSet( + HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + std::vector inputs = { + &GetInstructionValueSet( + conditional->true_computation()->root_instruction()), + &GetInstructionValueSet( + conditional->false_computation()->root_instruction())}; + // A phi-node is not defined for a kConditional instruction even though it + // represents a join point. This is because the current approach is to define + // a phi-node only for kWhile to account for the dataflow through back-edges + // and deal with the ambiguity in other cases. + return GetInstructionValueSet(conditional).AssignUnionOf(inputs); +} + bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { CHECK_EQ(copy->opcode(), HloOpcode::kCopy); bool changed = false; @@ -394,7 +409,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { CHECK_EQ(call_graph_node.context(), CallContext::kSequential); std::vector inputs; - bool called_from_while = false; + bool need_phi = false; for (const CallSite& callsite : call_graph_node.caller_callsites()) { if (callsite.instruction()->opcode() == HloOpcode::kCall) { // The operand values of a call instruction are forwarded to the @@ -416,14 +431,32 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { inputs.push_back(&GetInstructionValueSet( callsite.instruction()->while_body()->root_instruction())); } - called_from_while = true; + need_phi = true; + } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + CHECK_EQ(parameter->parameter_number(), 0); + auto conditional = callsite.instruction(); + // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is + // the argument to the true computation and operand 2 is the argument to + // the false computation. + // + // If the parameter belongs to conditional's true computation, then + // operand 1 is forwarded to this parameter instruction. If the parameter + // belongs to conditional's false computation, then operand 2 is forwarded + // to this parameter instruction. + if (parameter->parent() == conditional->true_computation()) { + inputs.push_back(&GetInstructionValueSet(conditional->operand(1))); + } else { + CHECK_EQ(parameter->parent(), conditional->false_computation()); + inputs.push_back(&GetInstructionValueSet(conditional->operand(2))); + } + need_phi = true; } else { LOG(FATAL) << "CallContext::kSequential computations should only be " - "called from call or while instructions"; + "called from call, while, or conditional instructions"; } } - if (ssa_form_ && called_from_while) { + if (ssa_form_ && need_phi) { return Phi(parameter, inputs); } else { return GetInstructionValueSet(parameter).AssignUnionOf(inputs); @@ -512,6 +545,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateSendValueSet(instruction); case HloOpcode::kRecvDone: return UpdateRecvDoneValueSet(instruction); + case HloOpcode::kConditional: + return UpdateConditionalValueSet(instruction); default: // Instruction does not forward HloValues (it defines all values in its // output). No update is necessary. @@ -550,13 +585,31 @@ void HloDataflowAnalysis::Propagate() { // If user sequentially calls a computation, then the respective // parameter(s) of the computation need to be updated. - for (HloComputation* called_computation : user->called_computations()) { - const CallGraphNode& call_graph_node = - call_graph_->GetNode(called_computation); - if (call_graph_node.context() == CallContext::kSequential) { - for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( - called_computation->parameter_instruction(operand_number)); + if (user->opcode() == HloOpcode::kConditional) { + // If operand 0 is the use of instruction, then no parameters need to be + // updated, since that is the predicate of the conditional. + // If operand 1 is the use of instruction, then the true_computation's + // parameter need to be updated. + // If operand 2 is the use of instruction, then the false_computation's + // parameter need to be updated. + // + // Note that the same instruction can be used in both operand 1 and + // operand 2. + if (user->operand(1) == instruction) { + worklist.push(user->true_computation()->parameter_instruction(0)); + } + if (user->operand(2) == instruction) { + worklist.push(user->false_computation()->parameter_instruction(0)); + } + } else { + for (HloComputation* called_computation : user->called_computations()) { + const CallGraphNode& call_graph_node = + call_graph_->GetNode(called_computation); + if (call_graph_node.context() == CallContext::kSequential) { + for (int64 operand_number : user->OperandIndices(instruction)) { + worklist.push( + called_computation->parameter_instruction(operand_number)); + } } } } @@ -568,7 +621,8 @@ void HloDataflowAnalysis::Propagate() { const CallGraphNode& call_graph_node = call_graph_->GetNode(instruction->parent()); for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if (callsite.instruction()->opcode() == HloOpcode::kCall) { + if ((callsite.instruction()->opcode() == HloOpcode::kCall) || + (callsite.instruction()->opcode() == HloOpcode::kConditional)) { worklist.push(callsite.instruction()); } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. @@ -636,6 +690,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { break; case HloOpcode::kWhile: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index dfd81ae951042f7a4d6d3c24af4d5b7e046c272d..469620d01295f90e0c36a48cac9be47c12473a68 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -146,6 +146,7 @@ class HloDataflowAnalysis { // the instruction value set changed. bool UpdateBitcastValueSet(HloInstruction* bitcast); bool UpdateCallValueSet(HloInstruction* call); + bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f08f0b1d6833b028baa5f997929a17eb5abae205..e714b2567fd1b3eab607a19f0bb7e3288150dc64 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace { +using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; // Test is parameterized on a bool which is whether the dataflow analysis is @@ -77,11 +78,23 @@ class HloDataflowAnalysisTest : public HloTestBase, analysis_->GetValueDefinedAt(b), *analysis_); } + std::unique_ptr CreateR0F32UnaryOpComputation( + HloOpcode opcode) { + HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode)); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, opcode, param0)); + return builder.Build(); + } + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { @@ -1528,6 +1541,315 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); } +TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { + // Test conditional with identity computations in both true and false cases. + // + // true_computation(F32[] %true_param): + // return %true_param + // + // false_computation(F32[] %false_param): + // return %false_param + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // return Conditional(%pred, %constant1, true_computation, + // %constant2, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "true_param")); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "false_param")); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, constant1, true_computation, constant2, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + ElementsAre(HloUse{conditional, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + ElementsAre(HloUse{conditional, 2, {}})); + + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); +} + +TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { + // Test conditional with true and false computations taking a tuple operand. + // + // true_computation((F32[], F32[]) %true_param): + // %true_x = GetTupleElement(%true_param, 0) + // %true_y = GetTupleElement(%true_param, 1) + // return Add(%true_x, %true_y) + // + // false_computation((F32[], F32[]) %false_param): + // %false_x = GetTupleElement(%false_param, 0) + // %false_y = GetTupleElement(%false_param, 1) + // return Subtract(%false_x, %false_y) + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // %tuple_operand = Tuple(%constant1, %constant2) + // return Conditional(%pred, %tuple_operand, true_computation, + // %tuple_operand, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "true_param")); + auto true_x = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0)); + auto true_y = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1)); + auto add = true_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, true_x, true_y)); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "false_param")); + auto false_x = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0)); + auto false_y = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1)); + auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kSubtract, false_x, false_y)); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, tuple_operand, true_computation, tuple_operand, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_y), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_y), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {0}}, + HloUse{conditional, 2, {0}}, + HloUse{add, 0, {}}, HloUse{sub, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {1}}, + HloUse{conditional, 2, {1}}, + HloUse{add, 1, {}}, HloUse{sub, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(), + UnorderedElementsAre( + HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}}, + HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, + HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); + + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); +} + +TEST_P(HloDataflowAnalysisTest, NestedConditionals) { + // computation1(F32[] %param1): + // %ceil = Ceil(%param1) + // return %ceil + // + // computation2(F32[] %param2): + // %floor = Floor(%param2) + // return %floor + // + // computation3(F32[] %param3): + // %negate = Negate(%param3) + // return %negate + // + // inner_conditional((PRED, F32[], F32[]) %param_cond): + // %pred_cond = GetTupleElement(%param_cond, 0) + // %true_operand_cond = GetTupleElement(%param_cond, 1) + // %false_opearnd_cond = GetTupleElement(%param_cond, 2) + // return Conditional(%pred_cond, %true_operand_cond, computation1, + // %false_operand_cond, computation2) + // + // entry: + // %pred1 = Constant(true) + // %pred2 = Constant(false) + // %constant1 = Constant(1.1); + // %constant2 = Constant(2.2); + // %constant3 = Constant(3.3); + // return Conditional(%pred1, (%pred2, %constant1, %constant2), + // inner_conditional, %constant3, computation3) + + auto computation1 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kCeil)); + auto computation2 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kFloor)); + auto computation3 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kNegate)); + + // Build inner_conditional computation. + const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {}); + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {scalar_bool_shape, scalar_shape_, scalar_shape_}); + auto inner_builder = + HloComputation::Builder(TestName() + "_inner_conditional"); + auto param_cond = inner_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond")); + auto pred_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0)); + auto true_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1)); + auto false_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2)); + auto inner_conditional = + inner_builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred_cond, true_operand_cond, computation1, + false_operand_cond, computation2)); + auto inner_conditional_computation = + module_->AddEmbeddedComputation(inner_builder.Build()); + + // Build entry computation. + auto builder = HloComputation::Builder(TestName()); + auto pred1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto pred2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.2f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(3.3f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({pred2, constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred1, tuple_operand, inner_conditional_computation, + constant3, computation3)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction())); + + auto computation1_param = computation1->parameter_instruction(0); + auto computation2_param = computation2->parameter_instruction(0); + auto computation3_param = computation3->parameter_instruction(0); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param), + analysis.GetValueDefinedAt(constant3)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond)); + EXPECT_EQ(analysis.GetUniqueValueAt(param_cond), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond), + analysis.GetValueDefinedAt(pred2)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index a4921232f5848dbe1789c4c641e2b0ba3c1848bb..1e5f0f797a13fd7e7ce1cc934387a274a74153bc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -37,6 +37,9 @@ namespace xla { StatusOr HloDCE::Run(HloModule* module) { bool changed = false; + VLOG(2) << "Before dce:"; + XLA_VLOG_LINES(2, module->ToString()); + for (auto* computation : module->MakeNonfusionComputations()) { std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( @@ -52,12 +55,15 @@ StatusOr HloDCE::Run(HloModule* module) { for (auto* instruction : computation->instructions()) { if (instruction->user_count() == 0 && live_instructions.count(instruction) == 0 && - computation->IsRemovable(instruction)) { + computation->IsRemovable(instruction) && + !instruction->HasSideEffect()) { dead_roots.push_back(instruction); } } for (HloInstruction* dead_root : dead_roots) { + VLOG(1) << "Removing dead root " << dead_root->ToString() + << " and it's unused operands"; TF_RETURN_IF_ERROR( computation->RemoveInstructionAndUnusedOperands(dead_root)); changed = true; @@ -87,6 +93,9 @@ StatusOr HloDCE::Run(HloModule* module) { } } + VLOG(2) << "After dce:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index d54b9a27087a42fd23eab0bd06e8deaca567312b..5a56607a665c4cbeb7b2572f182b88e890602968 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -70,6 +70,26 @@ TEST_F(HloDceTest, NoDeadCode) { EXPECT_EQ(3, computation->instruction_count()); } +TEST_F(HloDceTest, InstructionsWithSideEffect) { + // Verify that side-effect instructions (Send in this test) are not removed. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + + HloDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); +} + TEST_F(HloDceTest, DeadParameters) { // Verify that dead parameters are not removed, but use of the dead parameters // are. diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..1773bb401d380031f6c860d295e76d2f62c9e5ff --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -0,0 +1,137 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace { + +HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) { + if (hlo->shape().element_type() != type) { + Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); + hlo = hlo->parent()->AddInstruction( + HloInstruction::CreateConvert(shape, hlo)); + } + CHECK_EQ(hlo->shape().element_type(), type); + return hlo; +} + +bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == type) { + return true; + } + } + return false; +} + +} // namespace + +HloElementTypeConverter::HloElementTypeConverter( + PrimitiveType eliminate_type, PrimitiveType replace_with_type) + : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {} + +StatusOr HloElementTypeConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* hlo : computation->MakeInstructionPostOrder()) { + // These are ops where it does not make sense to convert them. + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant || + hlo->opcode() == HloOpcode::kTuple || + hlo->opcode() == HloOpcode::kConvert || + hlo->opcode() == HloOpcode::kGetTupleElement || + hlo->opcode() == HloOpcode::kInfeed || + hlo->opcode() == HloOpcode::kOutfeed) { + continue; + } + + // We cannot change a CustomCall since we have no way of adjusting the + // called binary to expect the updated type. + if (hlo->opcode() == HloOpcode::kCustomCall) { + continue; + } + + // These are ops with embedded computations where it suffices to convert + // the embedded computations instead of converting the ops themselves. + if (hlo->opcode() == HloOpcode::kWhile || + hlo->opcode() == HloOpcode::kCall || + hlo->opcode() == HloOpcode::kFusion || + hlo->opcode() == HloOpcode::kMap || + hlo->opcode() == HloOpcode::kReduce || + hlo->opcode() == HloOpcode::kReduceWindow || + hlo->opcode() == HloOpcode::kSelectAndScatter || + hlo->opcode() == HloOpcode::kConditional) { + continue; + } + TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); + + if (!HasOperandType(hlo, eliminate_type_)) { + // If this CHECK fires, then this was an instruction that does not take + // the elimination type as an operand but it does return it. This pass + // does not have a feature to change the output type in that case, so + // instead of silently failing to eliminate the type, it fails loudly. + TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_); + continue; + } + + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == eliminate_type_) { + operand = ToElementType(operand, replace_with_type_); + } + new_operands.push_back(operand); + } + + HloInstruction* new_hlo; + if (hlo->shape().element_type() == eliminate_type_) { + Shape shape = + ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + new_hlo = ToElementType(new_hlo, eliminate_type_); + } else { + new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( + hlo->shape(), new_operands, hlo->GetModule())); + } + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); + changed = true; + } + } + XLA_VLOG_LINES( + 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..2b109225d0b192e5c9e4f6d841377ffad8078dc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass that eliminates certain element types as the input or output of ops by +// inserting Convert ops. This allows a backend to support an element type while +// only actually implementing the Convert op for that element type. This is +// generally not the fastest approach, but it works. +class HloElementTypeConverter : public HloPassInterface { + public: + // eliminate_type is the type to eliminate as the input or output of ops, + // using Convert ops to replace it with replace_with_type. + HloElementTypeConverter(PrimitiveType eliminate_type, + PrimitiveType replace_with_type); + + tensorflow::StringPiece name() const override { + return "element_type_converter"; + } + + // Returns the pass on the module and returns whether the module was modified. + StatusOr Run(HloModule* module) override; + + private: + PrimitiveType eliminate_type_; + PrimitiveType replace_with_type_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index a722d1b3d99462f7252c259f74dcef1dfa4967b7..173f0e2c42bed2ea461eef27d811e0a626c4fee3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -167,11 +168,37 @@ StatusOr> ElementWiseUnaryOpImpl( } // namespace -template +template class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { public: explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} + // The following higher-order functions convert a function with ElementwiseT + // to a function with ReturnT. + std::function ConvertUnaryFunction( + const std::function& unary_op) { + return [&unary_op](ReturnT arg) { + return static_cast(unary_op(static_cast(arg))); + }; + } + std::function ConvertBinaryFunction( + const std::function& + binary_op) { + return [&binary_op](ReturnT arg1, ReturnT arg2) { + return static_cast(binary_op(static_cast(arg1), + static_cast(arg2))); + }; + } + std::function ConvertTernaryFunction( + const std::function& ternary_op) { + return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { + return static_cast(ternary_op(static_cast(arg1), + static_cast(arg2), + static_cast(arg3))); + }; + } + Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", HloOpcodeString(hlo_instruction->opcode()).c_str()); @@ -197,24 +224,25 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { is_complex_t::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) { return std::abs(elem_operand); })); return Status::OK(); } Status HandleAbs(HloInstruction* abs) override { - return HandleAbs(abs); + return HandleAbs(abs); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRound(HloInstruction* round) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[round], - ElementWiseUnaryOp(round, [](ReturnT elem_operand) { - return std::round(elem_operand); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[round], + ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { + return std::round(elem_operand); + })); return Status::OK(); } @@ -264,7 +292,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleCeil(HloInstruction* ceil) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], - ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { + ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { return std::ceil(elem_operand); })); return Status::OK(); @@ -299,7 +327,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleExp(HloInstruction* exp) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], - ElementWiseUnaryOp(exp, [](ReturnT elem_operand) { + ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { return std::exp(elem_operand); })); return Status::OK(); @@ -309,10 +337,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleFloor(HloInstruction* floor) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor], - ElementWiseUnaryOp(floor, [](ReturnT elem_operand) { - return std::floor(elem_operand); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { + return std::floor(elem_operand); + })); return Status::OK(); } @@ -329,18 +358,40 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], - ElementWiseUnaryOp(log, [](ReturnT elem_operand) { + ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { return std::log(elem_operand); })); return Status::OK(); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return ~elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { return !elem_operand; })); return Status::OK(); @@ -354,25 +405,47 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleNot(HloInstruction* not_) override { - return HandleNot(not_); + return HandleNot(not_); } - Status HandleNegate(HloInstruction* negate) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { - return -elem_operand; - })); + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { + return NativeT(-type(elem_operand)); + })); return Status::OK(); } + template ::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp( + negate, [](ElementwiseT elem_operand) { return -elem_operand; })); + return Status::OK(); + } + + Status HandleNegate(HloInstruction* negate) override { + return HandleNegate(negate); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { - return (ReturnT(0) < elem_operand) - - (elem_operand < ReturnT(0)); + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return (ElementwiseT(0) < elem_operand) - + (elem_operand < ElementwiseT(0)); })); return Status::OK(); } @@ -382,9 +455,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { auto abs_val = std::abs(elem_operand); - return 0 == abs_val ? ReturnT(0) + return 0 == abs_val ? ElementwiseT(0) : elem_operand / abs_val; })); return Status::OK(); @@ -396,45 +469,71 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleTanh(HloInstruction* tanh) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], - ElementWiseUnaryOp(tanh, [](ReturnT elem_operand) { + ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { return std::tanh(elem_operand); })); return Status::OK(); } - Status HandleMultiply(HloInstruction* multiply) override { + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + using type = typename std::make_unsigned::type; TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem * rhs_elem; - })); + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value || + std::is_floating_point::value || + is_complex_t::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem * rhs_elem; + })); + return Status::OK(); + } + + Status HandleMultiply(HloInstruction* multiply) override { + return HandleMultiply(multiply); + } + Status HandleSubtract(HloInstruction* subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem - rhs_elem; - })); + ElementWiseBinaryOp(subtract, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem - rhs_elem; + })); return Status::OK(); } Status HandleAdd(HloInstruction* add) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[add], - ElementWiseBinaryOp(add, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem + rhs_elem; - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem + rhs_elem; + })); return Status::OK(); } Status HandleDivide(HloInstruction* divide) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[divide], - ElementWiseBinaryOp(divide, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem / rhs_elem; - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem / rhs_elem; + })); return Status::OK(); } @@ -444,7 +543,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleMaximum(HloInstruction* maximum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { return std::fmax(lhs, rhs); })); return Status::OK(); @@ -458,18 +557,18 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleMaximum(HloInstruction* maximum) override { - return HandleMaximum(maximum); + return HandleMaximum(maximum); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::fmin(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmin(lhs_el, rhs_el); + })); return Status::OK(); } @@ -481,15 +580,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleMinimum(HloInstruction* minimum) override { - return HandleMinimum(minimum); + return HandleMinimum(minimum); } Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); return Status::OK(); } @@ -497,11 +596,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[remainder], - ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::fmod(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); + })); return Status::OK(); } @@ -513,16 +612,27 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleRemainder(HloInstruction* remainder) override { - return HandleRemainder(remainder); + return HandleRemainder(remainder); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleAnd(HloInstruction* and_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el & rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el && rhs_el; })); return Status::OK(); @@ -536,16 +646,27 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleAnd(HloInstruction* and_) override { - return HandleAnd(and_); + return HandleAnd(and_); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleOr(HloInstruction* or_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el | rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el || rhs_el; })); return Status::OK(); @@ -559,7 +680,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleOr(HloInstruction* or_) override { - return HandleOr(or_); + return HandleOr(or_); } template (shl); + return HandleShiftLeft(shl); } template (shra); + return HandleShiftRightArithmetic(shra); } template (shrl); + return HandleShiftRightLogical(shrl); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleClamp(HloInstruction* clamp) { - std::function clamp_op = - [](ReturnT low, ReturnT high, ReturnT value) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { return std::fmax(low, std::fmin(value, high)); }; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], - ElementWiseTernaryOp(clamp, std::move(clamp_op))); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); return Status::OK(); } @@ -661,7 +784,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleClamp(HloInstruction* clamp) override { - return HandleClamp(clamp); + return HandleClamp(clamp); } Status HandleSelect(HloInstruction* select) override { @@ -674,7 +797,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return on_false; }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementWiseTernaryOp(select, std::move(select_op))); + ElementwiseTernaryOp(select, std::move(select_op))); return Status::OK(); } @@ -724,7 +847,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); const auto& dnums = conv->convolution_dimension_numbers(); - const int64 num_spatial_dims = dnums.spatial_dimensions_size(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); @@ -771,7 +895,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size()); auto func = [&](tensorflow::gtl::ArraySlice out_index) { - ReturnT result_val = static_cast(0); + ElementwiseT result_val = static_cast(0); std::fill(lhs_index.begin(), lhs_index.end(), 0); std::fill(rhs_index.begin(), rhs_index.end(), 0); @@ -789,13 +913,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // Find corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { // Spatial dimension number for input (lhs) and output. - const int64 spatial_dim = dnums.spatial_dimensions(ki); + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(ki); // Calculate lhs (input) index without taking base dilation into // account. const auto& window_dim = window.dimensions(ki); const int64 undilated_index = - out_index[spatial_dim] * window_dim.stride() - + out_index[output_spatial_dim] * window_dim.stride() - window_dim.padding_low() + rhs_spatial_index[ki] * window_dim.window_dilation(); // Skip if the lhs (input) index is to be dilated. @@ -804,26 +930,30 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } // Calculate the actual lhs (input) index after dilation. - lhs_index[spatial_dim] = + lhs_index[input_spatial_dim] = undilated_index / window_dim.base_dilation(); // Skip if input index is not in bound. - if (!(lhs_index[spatial_dim] >= 0 && - lhs_index[spatial_dim] < lhs_shape.dimensions(spatial_dim))) { + if (!(lhs_index[input_spatial_dim] >= 0 && + lhs_index[input_spatial_dim] < + lhs_shape.dimensions(input_spatial_dim))) { goto cnt; } rhs_index[dnums.kernel_spatial_dimensions(ki)] = - rhs_spatial_index[ki]; + window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]; } - result_val += lhs_literal.Get(lhs_index) * - rhs_literal.Get(rhs_index); + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); } - cnt:; + cnt : {} } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); - return result_val; + return static_cast(result_val); }; auto result = Literal::CreateFromShape(result_shape); @@ -873,7 +1003,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = Literal::CreateFromShape(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { - ReturnT result_val = static_cast(0); + ElementwiseT result_val = static_cast(0); std::vector lhs_index(lhs_rank, 0); std::vector rhs_index(rhs_rank, 0); @@ -890,11 +1020,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { lhs_index[lhs_contracted_dimension] = i; rhs_index[rhs_contracted_dimension] = i; - result_val += lhs_literal.Get(lhs_index) * - rhs_literal.Get(rhs_index); + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); } - return result_val; + return static_cast(result_val); })); parent_->evaluated_[dot] = std::move(result); @@ -1080,6 +1211,97 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template + StatusOr> MapImpl(HloInstruction* map) { + auto operands = map->operands(); + HloComputation* computation = map->to_apply(); + + auto result = Literal::CreateFromShape(map->shape()); + + HloEvaluator embedded_evaluator; + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + std::vector> arg_literals; + arg_literals.reserve(operands.size()); + + // Construct scalar literal parameters to be passed to the map + // computation. + for (auto operand : operands) { + const Literal& arg_literal = + parent_->GetEvaluatedLiteralFor(operand); + + auto curr_val = arg_literal.Get(multi_index); + auto curr_val_literal = Literal::CreateR0(curr_val); + + arg_literals.push_back(std::move(curr_val_literal)); + } + + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate>(*computation, + arg_literals) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + + return computed_result->Get({}); + })); + return std::move(result); + } + + Status HandleMap(HloInstruction* map) override { + switch (map->operand(0)->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case C64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + default: + LOG(FATAL) << "HandleMap: unhandled primitive type for " + "input operand: " + << PrimitiveType_Name( + map->operand(0)->shape().element_type()); + } + + return Status::OK(); + } + Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); @@ -1126,6 +1348,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } } + HloEvaluator embedded_evaluator; // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -1145,13 +1368,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::vector args = {curr_val_literal.get(), result_val_literal.get()}; - // We need a new visitor for each evaluation, so that the same - // computation can be visited more than once (with different - // inputs). - HloEvaluator embedded_evaluator; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator.Evaluate(*function, args) .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); // Assign computed result to result_val. result_val = computed_result->Get({}); @@ -1208,6 +1430,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector window_index(window.dimensions_size()); DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + HloEvaluator embedded_evaluator; // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice output_index) { @@ -1239,14 +1462,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Literal::CreateR0(result_val); const std::vector args = {curr_val_literal.get(), result_val_literal.get()}; - // We need a new visitor for each evaluation, so that the same - // computation can be visited more than once (with different - // inputs). - HloEvaluator embedded_evaluator; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator.Evaluate(*function, args) .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + result_val = computed_result->Get({}); } while (IndexUtil::BumpIndices(window_shape, &window_index)); @@ -1287,6 +1510,50 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { + return std::sin(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + return InvalidArgument("Unsupported type for Sin"); + } + + Status HandleSin(HloInstruction* sin) override { + return HandleSin(sin); + } + + template ::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { + return std::cos(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + return InvalidArgument("Unsupported type for Cos"); + } + + Status HandleCos(HloInstruction* cos) override { + return HandleCos(cos); + } + private: template StatusOr> DynamicSlice( @@ -1349,22 +1616,27 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> ElementWiseUnaryOp( HloInstruction* instruction, - const std::function& unary_op) { + const std::function& unary_op) { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(instruction->operand(0)); - return ElementWiseUnaryOpImpl(instruction, unary_op, - operand_literal); + TF_ASSIGN_OR_RETURN( + auto result_literal, + (ElementWiseUnaryOpImpl( + instruction, ConvertUnaryFunction(unary_op), operand_literal))); + + return std::move(result_literal); } StatusOr> ElementWiseBinaryOp( HloInstruction* instruction, - const std::function& binary_op) { + const std::function& + binary_op) { const auto shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast + // is removed. if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { return Unimplemented( @@ -1382,14 +1654,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { - return binary_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); + return ConvertBinaryFunction(binary_op)( + lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); } template - StatusOr> ElementWiseTernaryOp( + StatusOr> ElementwiseTernaryOp( HloInstruction* instruction, const std::function& ternary_op) { const auto shape = instruction->shape(); @@ -1397,8 +1670,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const auto* rhs = instruction->operand(1); const auto* ehs = instruction->operand(2); - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. + // TODO(b/35950897, b/27796129): add DCHECK back once implicit + // broadcast is removed. if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { @@ -1451,9 +1724,11 @@ HloEvaluator::HloEvaluator() { typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); - typed_visitors_[BF16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: BF16."); - }); + // Most of the evaluator computations we use don't support BF16 (e.g., + // std::ceil, std::tanh). To make evaluator work with BF16, we set all + // elementwise computations to be done in F32 and do BF16<->F32 conversion + // around the input and the output of the computations. + typed_visitors_[BF16] = MakeUnique>(this); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); @@ -1462,13 +1737,17 @@ HloEvaluator::HloEvaluator() { }); } +template StatusOr> HloEvaluator::Evaluate( const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals) { + tensorflow::gtl::ArraySlice arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); - arg_literals_ = arg_literals; evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); @@ -1476,27 +1755,36 @@ StatusOr> HloEvaluator::Evaluate( GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())); } +template StatusOr> HloEvaluator::Evaluate( const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals) { + tensorflow::gtl::ArraySlice arg_literals) { XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); - arg_literals_ = arg_literals; + evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } TF_RETURN_IF_ERROR(computation.Accept(this)); return MakeUnique( GetEvaluatedLiteralFor(computation.root_instruction())); } +template StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction, - tensorflow::gtl::ArraySlice operands) { + tensorflow::gtl::ArraySlice arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); - arg_literals_ = operands; evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } // Evaluate operands of Parameter type against the input literals which // caches the evaluated literal results. @@ -1565,6 +1853,7 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( } std::vector operands; + operands.reserve(owned_operands.size()); for (auto& operand : owned_operands) { operands.push_back(operand.get()); } @@ -1583,9 +1872,13 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); - DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); + DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())) + << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape()) + << ", but input literal shape is: " + << ShapeUtil::HumanString(input_literal->shape()); evaluated_[parameter] = MakeUnique(*input_literal); return Status::OK(); @@ -1610,8 +1903,8 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { tensorflow::gtl::ArraySlice operands( concatenate->operands()); - // The result concatenate dimension is going to be the sum of all concatenate - // dimensions of the operands taking part of the operation. + // The result concatenate dimension is going to be the sum of all + // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); CHECK(!ShapeUtil::IsTuple(reference_shape)); const int64 rank = ShapeUtil::Rank(reference_shape); @@ -1821,4 +2114,30 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { return Status::OK(); } +// Explicit instantiation of templatized Evaluate* methods. +// +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(const HloModule& module, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + const HloModule& module, + tensorflow::gtl::ArraySlice> arg_literals); + +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(const HloComputation& computation, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + const HloComputation& computation, + tensorflow::gtl::ArraySlice> arg_literals); + +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(HloInstruction* instruction, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice> arg_literals); + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 7557aaa2484d184555411a79d8dce2c9241427b0..02bb8b0a47065c359603a113f49626bf3ad344d8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -42,9 +42,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. + template StatusOr> Evaluate( const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -62,9 +65,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. + template StatusOr> Evaluate( const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -72,10 +78,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. - // TODO(b/35950897): implement more ops other than element-wise ops. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. + template StatusOr> Evaluate( HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. @@ -100,12 +108,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault { protected: // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting // literal type of each evaluated Handle* method of a TypedVisitor. - // There are however a few notable exceptions to this is rule, notably: + // There are however a few notable exceptions to this rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. // These operations are handled outside of the parent HloEvaluator handlers // instead of from within TypedVisitor. - template + // + // Type params: + // - ReturnT: The type of input and output of each operation. + // - ElementwiseT: The type in which internal computation are done. + template class TypedVisitor; // Wraps around instruction handling to infer types before dispatching to @@ -134,6 +146,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleIsFinite(HloInstruction* is_finite) override; Status HandleCompare(HloInstruction* compare) override; + Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; @@ -167,13 +180,15 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in // post-orderring. + // Must be cleared for each evaluation. tensorflow::gtl::FlatMap> evaluated_; - // Stores input literals, assuming they are in post-order. Literals are not - // owned by this class, and they must outlive the lifetime of the instance of - // this class. - tensorflow::gtl::ArraySlice arg_literals_; + // Caches pointers to input literals, assuming they are in post-order. + // Literals are not owned by this class, and they must outlive the lifetime of + // each invocation to the Evaluate* method. + // Must be cleared for each evaluation. + std::vector arg_literals_; TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 85477af6fe26f53504c07204348566c16a24392c..97697d06b73e606351ab8dff638483aa7d844bfc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -25,8 +25,10 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -35,46 +37,124 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -class HloEvaluatorTest : public HloVerifiedTestBase { +static std::array use_bf16_params{true, false}; + +class HloEvaluatorTest : public ::testing::WithParamInterface, + public HloVerifiedTestBase { protected: - HloEvaluatorTest() { evaluator_ = MakeUnique(); } + HloEvaluatorTest() : use_bfloat16_(GetParam()) { + evaluator_ = MakeUnique(); + } + + std::unique_ptr Evaluate( + tensorflow::gtl::ArraySlice arg_literals = {}) { + if (use_bfloat16_) { + // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. + auto type_converter = HloElementTypeConverter(F32, BF16); + type_converter.Run(&module()).ValueOrDie(); + } + return evaluator_->Evaluate(*module().entry_computation(), arg_literals) + .ConsumeValueOrDie(); + } std::unique_ptr evaluator_; + + void TestUnaryOp(HloOpcode opcode, std::unique_ptr expected, + std::unique_ptr input, float aabs = 0) { + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + b.AddInstruction( + HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); + + auto element_type = expected->shape().element_type(); + if (element_type == F32 || element_type == F64) { + ErrorSpec error(aabs); + LiteralTestUtil::ExpectNear(*expected, *result, error); + } else { + LiteralTestUtil::ExpectEqual(*expected, *result); + } + } + + void TestBinaryOp(HloOpcode opcode, std::unique_ptr expected, + std::unique_ptr lhs, + std::unique_ptr rhs) { + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); + b.AddInstruction( + HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); + + LiteralTestUtil::ExpectEqual(*expected, *result); + } + + bool use_bfloat16_; }; +#define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ + TEST_P(test_case_name, test_name) + // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. -TEST_F(HloEvaluatorTest, DoesClamp) { +TEST_P(HloEvaluatorTest, DoesClamp) { auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); Shape shape = low->shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); - auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); - auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); - auto instruction = b.AddInstruction( + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); + b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); LiteralTestUtil::ExpectEqual(*expected, *result); } +TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { + auto low = Literal::CreateR0(0.f); + auto value = Literal::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); + auto high = Literal::CreateR0(1.f); + + Shape shape = value->shape(); + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); + b.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); + + auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. -TEST_F(HloEvaluatorTest, DoesSelect) { +TEST_P(HloEvaluatorTest, DoesSelect) { auto pred = Literal::CreateR2({{true, false}, {false, true}}); auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -86,12 +166,11 @@ TEST_F(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true))); auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); @@ -100,126 +179,108 @@ TEST_F(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. -TEST_F(HloEvaluatorTest, DoesAdd) { +TEST_P(HloEvaluatorTest, DoesAdd) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - - Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); - auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); - auto instruction = b.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise and with 2 operands. +TEST_P(HloEvaluatorTest, DoesAnd) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto expected = Literal::CreateR2({{0, 0}, {4, 4}}); + TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise or with 2 operands. +TEST_P(HloEvaluatorTest, DoesOr) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto expected = Literal::CreateR2({{3, 4}, {-100, 4}}); + TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise multiply with 2 operands. +TEST_P(HloEvaluatorTest, DoesMultiply) { + auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2( + {{std::numeric_limits::min(), 4}, {4, 4}}); + auto expected = Literal::CreateR2( + {{std::numeric_limits::min(), 0}, {-400, 16}}); + TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs), + std::move(rhs)); } - // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_F(HloEvaluatorTest, DoesDivideInt64) { - auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); - - Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64))); - auto c2_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_s64, HloOpcode::kDivide, c1_s64, c2_s64)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - +TEST_P(HloEvaluatorTest, DoesDivideInt64) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), + std::move(rhs)); } -TEST_F(HloEvaluatorTest, DoesDivideDouble) { - auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); - - Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64))); - auto c2_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_f64, HloOpcode::kDivide, c1_f64, c2_f64)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - +TEST_P(HloEvaluatorTest, DoesDivideDouble) { + auto lhs = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), + std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. -TEST_F(HloEvaluatorTest, DoesAbsR2) { +TEST_P(HloEvaluatorTest, DoesAbsR2) { auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); - const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = - b.AddInstruction(HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesAbsR0) { - // For R0 literal. - const Shape& r0 = ShapeUtil::MakeShape(F32, {}); +TEST_P(HloEvaluatorTest, DoesAbsR0) { auto operand = Literal::CreateR0(-1.0f); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = - b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); auto expected = Literal::CreateR0(1.0f); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) { - // For R1 literal with dimension of size 0. - Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); +TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { auto operand = Literal::CreateR1({}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = b.AddInstruction( - HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); auto expected = Literal::CreateR1({}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); +} +TEST_P(HloEvaluatorTest, DoesNegateR2) { + auto operand = Literal::CreateR2( + {{0, std::numeric_limits::min()}, {-1, 4}}); + auto expected = + Literal::CreateR2({{0, std::numeric_limits::min()}, {1, -4}}); + TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); +} +TEST_P(HloEvaluatorTest, DoesCosR2) { + auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); + TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), + use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); +} +TEST_P(HloEvaluatorTest, DoesSinR2) { + auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); + TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), + use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); +} +TEST_P(HloEvaluatorTest, DoesNotR2) { + auto operand = + Literal::CreateR2({{0, std::numeric_limits::min()}, + {-1, std::numeric_limits::max()}}); + auto expected = + Literal::CreateR2({{-1, std::numeric_limits::max()}, + {0, std::numeric_limits::min()}}); + TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } - // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. -TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { +TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); @@ -239,10 +300,9 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, args).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(args); auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); @@ -250,7 +310,7 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { } // Verifies Reshape operation is correctly evaluated. -TEST_F(HloEvaluatorTest, DoesReshape) { +TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, @@ -264,21 +324,20 @@ TEST_F(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; result->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_TRUE(value == literal_clone->Get(rindexes)); + EXPECT_NEAR(value, literal_clone->Get(rindexes), 0x1.0P-5); }); } // Verifies Broadcast operation is correctly evaluated. -TEST_F(HloEvaluatorTest, DoesBroadcast) { +TEST_P(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto output_literal = Literal::CreateR3( @@ -287,15 +346,14 @@ TEST_F(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, {1, 2})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); LiteralTestUtil::ExpectEqual(*result, *output_literal); } -TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { +TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR0(111); auto output_literal = Literal::CreateR2( @@ -307,15 +365,14 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, /*broadcast_dimensions=*/{})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); LiteralTestUtil::ExpectEqual(*result, *output_literal); } -TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { +TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( @@ -328,17 +385,16 @@ TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { +TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( @@ -351,16 +407,15 @@ TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({100, 200}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { +TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); @@ -372,15 +427,14 @@ TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); LiteralTestUtil::ExpectEqual(*result, *expected); } -TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { +TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2WithLayout( @@ -393,10 +447,9 @@ TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); LiteralTestUtil::ExpectEqual(*result, *expected); } @@ -414,7 +467,7 @@ PaddingConfig CreatePaddingConfig( return padding_config; } -TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { +TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto operand = Literal::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = @@ -427,11 +480,11 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto padding_config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}}); Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); - auto pad_instruction = b.AddInstruction(HloInstruction::CreatePad( + b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - auto result = evaluator_->Evaluate(pad_instruction).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); @@ -439,7 +492,7 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -456,10 +509,9 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = MakeUnique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -475,7 +527,7 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, NegativePadding2D) { +TEST_P(HloEvaluatorTest, NegativePadding2D) { HloComputation::Builder b(TestName()); // input_array: @@ -501,10 +553,9 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = MakeUnique>(1, 5); @@ -515,10 +566,10 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); } -TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { +TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { HloComputation::Builder b(TestName()); // f32[4,3] { @@ -547,10 +598,9 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = MakeUnique>(0, 9); auto expected = Literal::CreateR2FromArray2D(*expected_array); @@ -558,7 +608,7 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank2AndRank1) { +TEST_P(HloEvaluatorTest, DotRank2AndRank1) { HloComputation::Builder b(TestName()); // lhs: @@ -581,12 +631,14 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off auto expected_array = Array2D({ @@ -601,7 +653,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank1AndRank2) { +TEST_P(HloEvaluatorTest, DotRank1AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -624,19 +676,21 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({22.f, 28.f}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank2AndRank2) { +TEST_P(HloEvaluatorTest, DotRank2AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -665,12 +719,14 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = Array2D({ {22.f, 28.f}, @@ -683,7 +739,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SimpleConv1D) { +TEST_P(HloEvaluatorTest, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; @@ -711,7 +767,8 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_kernel_output_feature_dimension(0); dnums.set_kernel_input_feature_dimension(1); @@ -720,10 +777,9 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = Literal::CreateR3FromArray3D(expected_array); @@ -731,7 +787,7 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { +TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -775,10 +831,9 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -794,7 +849,7 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { +TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { HloComputation::Builder b(TestName()); // clang-format off @@ -826,6 +881,8 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { auto rhs_literal = Literal::CreateR4FromArray4D(weight); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse( + rhs_instruction->shape(), rhs_instruction, {3, 1})); Window window; WindowDimension dim; @@ -835,6 +892,7 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { dim.set_padding_high(0); dim.set_window_dilation(1); dim.set_base_dilation(1); + dim.set_window_reversal(true); *window.add_dimensions() = dim; *window.add_dimensions() = dim; @@ -843,8 +901,10 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { dnums.set_output_batch_dimension(2); dnums.set_input_feature_dimension(0); dnums.set_output_feature_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_kernel_output_feature_dimension(0); dnums.set_kernel_input_feature_dimension(2); @@ -854,21 +914,99 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); + Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = Literal::CreateR4FromArray4D( + use_bfloat16_ ? expected_array_bf16 : expected_array); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { +TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { + HloComputation::Builder b(TestName()); + + // clang-format off + // Input dimensions: [feature=2, height=3, batch=1, width=4] + Array4D input({ + {{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}}, + {{{13, 14, 15, 16}}, + {{17, 18, 19, 20}}, + {{21, 22, 23, 24}}} + }); + // Weight dimensions: + // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3] + Array4D weight({{ + {{1, 7, 13}, + {4, 10, 16}}, + {{2, 8, 14}, + {5, 11, 17}}, + {{3, 9, 15}, + {6, 12, 18}} + }}); + // clang-format on + + auto lhs_literal = Literal::CreateR4FromArray4D(input); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + auto rhs_literal = Literal::CreateR4FromArray4D(weight); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + + Window window; + WindowDimension dim; + dim.set_size(3); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(2); + dnums.set_output_batch_dimension(2); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); + + dnums.set_kernel_output_feature_dimension(0); + dnums.set_kernel_input_feature_dimension(2); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(1); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); + + // clang-format off + // Result dimensions: [feature=1, height=1, batch=1, width=2] + Array4D expected_array({{{{2514, 2685}}}}); + Array4D expected_array_bf16({{{{2512, 2672}}}}); + // clang-format on + auto expected = Literal::CreateR4FromArray4D( + use_bfloat16_ ? expected_array_bf16 : expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -912,10 +1050,9 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -932,7 +1069,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { +TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -976,10 +1113,9 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -997,7 +1133,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, +TEST_P(HloEvaluatorTest, DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { HloComputation::Builder b(TestName()); @@ -1048,10 +1184,9 @@ TEST_F(HloEvaluatorTest, const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1070,7 +1205,7 @@ TEST_F(HloEvaluatorTest, LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceAdd) { +TEST_P(HloEvaluatorTest, ReduceAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1103,17 +1238,16 @@ TEST_F(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({6, 18}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceWindowMax) { +TEST_P(HloEvaluatorTest, ReduceWindowMax) { HloComputation::Builder b(TestName()); // arg: @@ -1156,15 +1290,15 @@ TEST_F(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{6, 7}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceWindowAdd) { +TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1213,15 +1347,15 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { +TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { HloComputation::Builder b(TestName()); // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. @@ -1274,9 +1408,9 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = @@ -1284,7 +1418,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { LiteralTestUtil::ExpectEqual(*result_literal, *result); } -TEST_F(HloEvaluatorTest, StridedSlice) { +TEST_P(HloEvaluatorTest, StridedSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1305,10 +1439,9 @@ TEST_F(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {3}, @@ -1318,7 +1451,7 @@ TEST_F(HloEvaluatorTest, StridedSlice) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DynamicSlice) { +TEST_P(HloEvaluatorTest, DynamicSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1339,10 +1472,9 @@ TEST_F(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {2, 3, 4}, @@ -1354,7 +1486,7 @@ TEST_F(HloEvaluatorTest, DynamicSlice) { // Verifies that the HloEvaluator's implementation goes along with existing // backends' behavior, although this is not required by the spec. -TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { +TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1375,10 +1507,9 @@ TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {2, 3, 4}, @@ -1388,7 +1519,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { +TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { HloComputation::Builder b(TestName()); // arg: @@ -1412,10 +1543,9 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, operand, update, start_indices)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {1, -2, -3}, @@ -1425,7 +1555,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SetAndGetTuples) { +TEST_P(HloEvaluatorTest, SetAndGetTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1448,9 +1578,9 @@ TEST_F(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {1, 2, 3}, @@ -1460,7 +1590,7 @@ TEST_F(HloEvaluatorTest, SetAndGetTuples) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { +TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1487,9 +1617,9 @@ TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto result_inner_literal = Literal::CreateR2FromArray2D(*operand_array); @@ -1501,7 +1631,7 @@ TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Reverse) { +TEST_P(HloEvaluatorTest, Reverse) { HloComputation::Builder b(TestName()); // Input shape is float[4x3x2x1]. @@ -1527,10 +1657,9 @@ TEST_F(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off auto expected = Literal::CreateR4FromArray4D({ @@ -1555,7 +1684,7 @@ TEST_F(HloEvaluatorTest, Reverse) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, EvaluateWithSubstitutions) { +TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1578,7 +1707,7 @@ TEST_F(HloEvaluatorTest, EvaluateWithSubstitutions) { // Check that EvaluateWithSubstitutions works if one of the operands to the op // we're evaluating is a constant. -TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { +TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1600,5 +1729,8 @@ TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { *result.ValueOrDie()); } +INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, + ::testing::ValuesIn(use_bf16_params)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index 755374b91d05f4b6186e75e98847cbd3ffed0e93..0111cfd5a3d7889f80370f9e3e744457bc4091e4 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -40,7 +40,7 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { } } -static HloProfilePrinter CreateOwnedHloProfilePrinter( +std::unique_ptr CreateHloProfilePrinter( const HloProfileIndexMap& hlo_profile_index_map, const HloCostAnalysis& cost_analysis) { using HloComputationInfo = HloProfilePrinter::HloComputationInfo; @@ -76,14 +76,14 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter( HloProfilePrinter::HloInstructionInfo* instruction_info = &computation_info->instructions[instruction_index_in_static_data++]; instruction_info->long_name = strdup(hlo->ToString().c_str()); - instruction_info->short_name = - strdup(hlo->ToString(/*compact_operands=*/true).c_str()); + instruction_info->short_name = strdup( + hlo->ToString(HloPrintOptions().set_compact_operands(true)).c_str()); instruction_info->category = strdup(hlo->ToCategory().c_str()); instruction_info->flop_count = cost_analysis.flop_count(*hlo); instruction_info->transcendental_count = cost_analysis.transcendental_count(*hlo); instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo); - instruction_info->seconds = cost_analysis.seconds(*hlo); + instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo); instruction_info->profile_index = hlo_profile_index_map.GetProfileIndexFor(*hlo); CHECK_LT(instruction_info->profile_index, max_profile_index); @@ -108,15 +108,16 @@ static HloProfilePrinter CreateOwnedHloProfilePrinter( delete[] computation_infos; }; - return HloProfilePrinter(computation_infos, - hlo_profile_index_map.computation_count(), deleter); + return MakeUnique( + computation_infos, hlo_profile_index_map.computation_count(), + /*profile_counters_size=*/max_profile_index, deleter); } -HloExecutionProfile::HloExecutionProfile(const HloModule& module, - const HloCostAnalysis& cost_analysis) - : hlo_profile_index_map_(module), - hlo_profile_printer_( - CreateOwnedHloProfilePrinter(hlo_profile_index_map_, cost_analysis)), +HloExecutionProfile::HloExecutionProfile( + const HloProfilePrinter* hlo_profile_printer, + const HloProfileIndexMap* hlo_profile_index_map) + : hlo_profile_printer_(*hlo_profile_printer), + hlo_profile_index_map_(*hlo_profile_index_map), profile_counters_( /*count*/ hlo_profile_index_map_.total_count(), /*value*/ 0) {} @@ -131,10 +132,4 @@ uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)]; } -string HloExecutionProfile::ToString( - const DeviceDescription& device_description) const { - return hlo_profile_printer_.ToString(profile_counters_.data(), - device_description.clock_rate_ghz()); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 84702680c0c40335098530c4b1fdb164bb7f9374..470fd4ce3c205d84152238f4b18daad77e403f68 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -77,6 +77,11 @@ class HloProfileIndexMap { std::unordered_map computation_to_profile_idx_; }; +// Create an instance of `HloProfilePrinter` that owns its memory. +std::unique_ptr CreateHloProfilePrinter( + const HloProfileIndexMap& hlo_profile_index_map, + const HloCostAnalysis& cost_analysis); + // Describes how much time each HLO operation took. // // Each HloComputation takes a certain number of cycles. This class helps break @@ -85,8 +90,8 @@ class HloExecutionProfile { public: using DeviceDescription = perftools::gputools::DeviceDescription; - HloExecutionProfile(const HloModule& module, - const HloCostAnalysis& cost_analysis); + HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer, + const HloProfileIndexMap* hlo_profile_index_map); // Record how many cycles this HLO took to execute. void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken); @@ -114,15 +119,16 @@ class HloExecutionProfile { // for the operations in a given computation. Returns an empty string if it // wasn't possible to generate a printable version. cost_analysis should be a // clean analysis that can be used to visit the computation. - string ToString(const DeviceDescription& device_description) const; + string ToString(const DeviceDescription& device_description) const { + return hlo_profile_printer_.ToString(profile_counters_.data(), + device_description.clock_rate_ghz()); + } - private: - // hlo_profile_index_map_ maps an Hlo entity (computation or instruction) to - // an index in profile_counters_. - HloProfileIndexMap hlo_profile_index_map_; + std::vector* mutable_profile_counters() { return &profile_counters_; } - // Used to print profile_counters_ in a human readable form. - HloProfilePrinter hlo_profile_printer_; + private: + const HloProfilePrinter& hlo_profile_printer_; + const HloProfileIndexMap& hlo_profile_index_map_; // Stores per-Hlo profile counters. This is the only thing that changes when // we execute an XLA computation. diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 0628444b34b017297d5da7980202e4c5586879ab..b1e6729e2bccad4bdbe075a635d8a9b1ede6fecb 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -72,7 +72,11 @@ TEST_F(HloExecutionProfileTest, Basic) { }; HloCostAnalysis cost_analysis(shape_size_function); - HloExecutionProfile execution_profile(*hlo_module, cost_analysis); + HloProfileIndexMap profile_index_map(*hlo_module); + std::unique_ptr profile_printer = + CreateHloProfilePrinter(profile_index_map, cost_analysis); + HloExecutionProfile execution_profile(profile_printer.get(), + &profile_index_map); const int64 add_cycles = 1000; const int64 dot_cycles = 4000; @@ -90,10 +94,10 @@ TEST_F(HloExecutionProfileTest, Basic) { const std::vector& line_3 = lines_and_words[3]; EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); - EXPECT_EQ(line_2[kInstructionNameIndex], dot_instruction->name()); + EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name()); EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); - EXPECT_EQ(line_3[kInstructionNameIndex], add_instruction->name()); + EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name()); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index d71a4b42c71154a25d1e6ec029ba3922361fd0b9..44db09208544a4372f37861b0a2a824faa593d60 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -864,9 +864,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { // (eg, parameter). switch (instr->opcode()) { case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kComplex: @@ -882,18 +883,19 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kAnd: - case HloOpcode::kNot: - case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kRng: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -903,7 +905,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSort: case HloOpcode::kSubtract: case HloOpcode::kTanh: - case HloOpcode::kRng: // De-emphasize scalar-shaped elementwise ops -- they're generally // uninteresting. if (ShapeUtil::IsEffectiveScalar(instr->shape())) { @@ -911,9 +912,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kYellow; case HloOpcode::kBitcast: - case HloOpcode::kTuple: - case HloOpcode::kTrace: case HloOpcode::kGetTupleElement: + case HloOpcode::kTrace: + case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: // De-emphasize nodes which broadcast a scalar within a fusion node -- @@ -952,28 +953,28 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kRed; case HloOpcode::kParameter: return kParameterColor; - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: case HloOpcode::kReduce: - case HloOpcode::kSelectAndScatter: case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: return kPurple; - case HloOpcode::kMap: case HloOpcode::kFusion: + case HloOpcode::kMap: return kGray; - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: + case HloOpcode::kCrossReplicaSum: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: return kBrown; + case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: case HloOpcode::kWhile: - case HloOpcode::kCall: return kDarkGreen; case HloOpcode::kConstant: LOG(FATAL) << "Constants don't get their own nodes in the graph."; @@ -1055,7 +1056,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { case HloOpcode::kBatchNormGrad: return Printf("feature_index=%lld", instr->feature_index()); case HloOpcode::kCustomCall: - return Printf("custom_call_target=%s", instr->custom_call_target()); + return Printf("target=%s", instr->custom_call_target()); case HloOpcode::kSlice: return std::all_of(instr->slice_strides().begin(), instr->slice_strides().end(), @@ -1090,7 +1091,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { instr->shape().dimensions_size() > 1 && !ShapeUtil::IsTuple(instr->shape())) { StrAppend(&instr_shape, "{", - Join(instr->shape().layout().minor_to_major(), ","), "}"); + Join(LayoutUtil::MinorToMajor(instr->shape()), ","), "}"); } // Some instructions have giant tuples as their shapes, so truncate the @@ -1353,18 +1354,17 @@ string SaveGraph(const string& graph, break; } string path = JoinPath( - dest_path, StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); + dest_path, StrCat("hlo_graph_", output_num++, ".")); auto status = Status::OK(); - int fd = mkstemps(&path[0], file_extension.length()); - if (fd < 0) { + auto env = tensorflow::Env::Default(); + if (!env->CreateUniqueFileName(&path, file_extension)) { status = Status(tensorflow::error::Code::UNKNOWN, StrCat("Failed to create temporary file to dump HLO graph: ", strerror(errno))); } else { status = - tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph); - close(fd); + tensorflow::WriteStringToFile(env, path, graph); } if (!status.ok()) { LOG(WARNING) << "Saving HLO graph failed: " << status; @@ -1437,7 +1437,8 @@ void DumpText(const HloModule& module, const string& label, do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); string path = JoinPath(directory_path, filename); TF_CHECK_OK(WriteStringToFile( - env, path, module.ToString(/*include_large_constants=*/true))); + env, path, + module.ToString(HloPrintOptions().set_print_large_constants(true)))); LOG(INFO) << "dumping module '" << module.name() << "' to " << path; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c35ca1eb992d98d10a0af1ca2327bcb93c2b4972..89a95b2b991b061acdb5701dc7507b6b0a33fe73 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -52,7 +52,9 @@ using ::tensorflow::strings::StrCat; StatusOr> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, - tensorflow::gtl::FlatMap* computation_map) { + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -78,19 +80,19 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), computation_map, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back( - module->AddEmbeddedComputation(std::move(fused_computation))); + TF_ASSIGN_OR_RETURN(std::unique_ptr fused_computation, + HloComputation::CreateFromProto( + module, proto.fused_instructions_computation(), + computation_map, add_fused_computation, + /*fusion_instruction=*/instruction.get())); + instruction->called_computations_.push_back(fused_computation.get()); + add_fused_computation(std::move(fused_computation)); } else { for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(*computation_map, computation_name)) + TF_RET_CHECK(ContainsKey(computation_map, computation_name)) << "No computation named " << computation_name; instruction->called_computations_.push_back( - computation_map->at(computation_name)); + computation_map.at(computation_name)); } } @@ -102,7 +104,6 @@ StatusOr> HloInstruction::CreateFromProto( instruction->literal_ = MakeUnique(proto.literal()); } instruction->parameter_number_ = proto.parameter_number(); - instruction->parameter_name_ = proto.parameter_name(); instruction->tuple_index_ = proto.tuple_index(); for (int64 dimension : proto.dimensions()) { @@ -116,6 +117,10 @@ StatusOr> HloInstruction::CreateFromProto( MakeUnique( proto.convolution_dimension_numbers()); } + if (proto.has_dot_dimension_numbers()) { + instruction->dot_dimension_numbers_ = + MakeUnique(proto.dot_dimension_numbers()); + } for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { instruction->slice_starts_.push_back(slice_dimensions.start()); @@ -148,8 +153,7 @@ StatusOr> HloInstruction::CreateFromProto( auto instruction = WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); instruction->parameter_number_ = parameter_number; - instruction->parameter_name_ = name; - instruction->name_ = "%" + name; + instruction->name_ = name; return instruction; } @@ -330,6 +334,31 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = + MakeUnique(dimension_numbers); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateCanonicalDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = MakeUnique(); + instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); + instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, @@ -344,12 +373,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape, } /* static */ std::unique_ptr -HloInstruction::CreateCrossReplicaSum(const Shape& shape, - HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); - instruction->AppendOperand(operand); - return instruction; +HloInstruction::CreateCrossReplicaSum( + const Shape& shape, tensorflow::gtl::ArraySlice operands) { + return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -436,6 +462,23 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateConditional( + const Shape& shape, HloInstruction* pred, + HloInstruction* true_computation_arg, HloComputation* true_computation, + HloInstruction* false_computation_arg, HloComputation* false_computation) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + instruction->AppendOperand(pred); + instruction->AppendOperand(true_computation_arg); + instruction->AppendOperand(false_computation_arg); + // In called_computations_, the index of true_computation must be 0 and that + // of false computation must be 1, as defined by kTrueComputationIndex and + // kFalseComputationIndex. + instruction->called_computations_.push_back(true_computation); + instruction->called_computations_.push_back(false_computation); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, @@ -499,6 +542,15 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBitcastConvert(const Shape& shape, + HloInstruction* operand) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); + instruction->AppendOperand(operand); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReduce( const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, @@ -631,7 +683,10 @@ HloInstruction::CreateSelectAndScatter( CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); CHECK(std::equal(operand->shape().dimensions().begin(), operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())); + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(shape) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; auto instruction = WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); instruction->AppendOperand(operand); @@ -791,7 +846,7 @@ HloInstruction* HloInstruction::FuseInstructionInternal( HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(instruction_to_fuse->IsFusable()); + CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations_.empty()) { @@ -869,10 +924,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // parameter instruction. int64 param_no = fused_parameters.size(); // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. Strip the leading "%" from the operand name - // to avoid a double %%. - string param_name = - StrCat(operand->name().substr(1), ".param_", param_no); + // (non-fusion) computation. + string param_name = StrCat(operand->name(), ".param_", param_no); fused_param = fused_instructions_computation()->AddParameter( CreateParameter(param_no, operand->shape(), param_name)); AppendOperand(operand); @@ -956,6 +1009,7 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kSendDone: case HloOpcode::kRecv: case HloOpcode::kRecvDone: + case HloOpcode::kRng: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: @@ -1013,7 +1067,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { - VLOG(3) << " " << new_operand->name(); + VLOG(3) << " %" << new_operand->name(); } std::unique_ptr clone; @@ -1057,7 +1111,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kLe: case HloOpcode::kLt: case HloOpcode::kNe: - case HloOpcode::kDot: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1095,6 +1148,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); break; + case HloOpcode::kBitcastConvert: + CHECK_EQ(new_operands.size(), 1); + clone = CreateBitcastConvert(shape, new_operands[0]); + break; case HloOpcode::kReducePrecision: CHECK_EQ(new_operands.size(), 1); clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, @@ -1105,9 +1162,13 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, *convolution_dimension_numbers_); break; + case HloOpcode::kDot: + CHECK_EQ(new_operands.size(), 2); + clone = CreateDot(shape, new_operands[0], new_operands[1], + *dot_dimension_numbers_); + break; case HloOpcode::kCrossReplicaSum: - CHECK_EQ(new_operands.size(), 1); - clone = CreateCrossReplicaSum(shape, new_operands[0]); + clone = CreateCrossReplicaSum(shape, new_operands); break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); @@ -1182,7 +1243,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CloneFusionWithNewOperands(shape, new_operands, module); break; case HloOpcode::kParameter: - clone = CreateParameter(parameter_number_, shape, parameter_name_); + clone = CreateParameter(parameter_number_, shape, name_); break; case HloOpcode::kBatchNormTraining: CHECK_EQ(new_operands.size(), 3); @@ -1211,6 +1272,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[4], epsilon(), feature_index()); break; case HloOpcode::kConditional: + CHECK_EQ(new_operands.size(), 3); + clone = CreateConditional(shape, new_operands[0], new_operands[1], + true_computation(), new_operands[2], + false_computation()); + break; case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kSend: @@ -1476,7 +1542,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: @@ -1535,6 +1600,7 @@ bool HloInstruction::IdenticalSlowPath( // A convert result is determined by the primitive type that the operand is // converted into. case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: return shape().element_type() == other.shape().element_type(); // A reduce-precision operation is determined by the bit sizes. @@ -1548,6 +1614,10 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals( convolution_dimension_numbers(), other.convolution_dimension_numbers()); + // Check dot dimension numbers. + case HloOpcode::kDot: + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + other.dot_dimension_numbers()); // Reduction results are determined by the reduction dimension and the // reduction computation. @@ -1814,6 +1884,32 @@ void HloInstruction::set_scatter(HloComputation* computation) { called_computations_[kScatterComputationIndex] = computation; } +HloComputation* HloInstruction::true_computation() const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + return called_computations_[kTrueComputationIndex]; +} + +HloComputation* HloInstruction::false_computation() const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + return called_computations_[kFalseComputationIndex]; +} + +void HloInstruction::set_true_computation(HloComputation* true_computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + CHECK_EQ(HloOpcode::kConditional, opcode_); + called_computations_[kTrueComputationIndex] = true_computation; +} + +void HloInstruction::set_false_computation(HloComputation* false_computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + CHECK_EQ(HloOpcode::kConditional, opcode_); + called_computations_[kFalseComputationIndex] = false_computation; +} + string HloInstruction::SignatureString() const { string operands = Join(operands_, ", ", [](string* out, HloInstruction* operand) { @@ -1822,16 +1918,23 @@ string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -string HloInstruction::ToString(bool compact_operands, bool include_metadata, - bool include_large_constants) const { +namespace { + +string PrintName(const string& name, const HloPrintOptions& options) { + return StrCat(options.print_percent() ? "%" : "", name); +} + +} // namespace + +string HloInstruction::ToString(const HloPrintOptions& options) const { string result = - StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", - OperandsToString(compact_operands, include_large_constants), ")"); - for (const string& extra : ExtraAttributesToString()) { + StrCat(PrintName(name(), options), " = ", + ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", OperandsToString(options), ")"); + for (const string& extra : ExtraAttributesToString(options)) { StrAppend(&result, ", ", extra); } - if (include_metadata && + if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); @@ -1839,14 +1942,13 @@ string HloInstruction::ToString(bool compact_operands, bool include_metadata, return result; } -string HloInstruction::OperandsToString(bool compact, - bool include_large_constants) const { +string HloInstruction::OperandsToString(const HloPrintOptions& options) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || - include_large_constants) { + options.print_large_constants()) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); @@ -1871,14 +1973,19 @@ string HloInstruction::OperandsToString(bool compact, } else { tensorflow::gtl::ArraySlice slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; - if (compact && slice.size() > kMaxOperandsToShowIfCompact) { + if (options.compact_operands() && + slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - *out += ShapeUtil::HumanStringWithLayout(operand->shape()); - if (!compact) { - StrAppend(out, " ", operand->name()); + std::vector str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); + } + if (!options.compact_operands()) { + str.push_back(PrintName(operand->name(), options)); } + StrAppend(out, Join(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { @@ -1888,7 +1995,8 @@ string HloInstruction::OperandsToString(bool compact, return operands; } -std::vector HloInstruction::ExtraAttributesToString() const { +std::vector HloInstruction::ExtraAttributesToString( + const HloPrintOptions& options) const { std::vector extra; if (opcode() == HloOpcode::kFusion) { extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); @@ -1896,7 +2004,7 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (CanHaveDimensionsField()) { extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); } - if (window_ != nullptr) { + if (window_ != nullptr && window_->dimensions_size() != 0) { extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } if (padding_config_ != nullptr) { @@ -1930,22 +2038,33 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (convolution_dimension_numbers_ != nullptr) { extra.push_back(ConvolutionDimensionNumbersToString()); } + if (dot_dimension_numbers_ != nullptr) { + extra.push_back(DotDimensionNumbersToString()); + } if (opcode() == HloOpcode::kWhile) { - extra.push_back(StrCat("condition=%", while_condition()->name())); - extra.push_back(StrCat("body=%", while_body()->name())); + extra.push_back( + StrCat("condition=", PrintName(while_condition()->name(), options))); + extra.push_back(StrCat("body=", PrintName(while_body()->name(), options))); } else if (opcode() == HloOpcode::kSelectAndScatter) { - extra.push_back(StrCat("select=%", select()->name())); - extra.push_back(StrCat("scatter=%", scatter()->name())); + extra.push_back(StrCat("select=", PrintName(select()->name(), options))); + extra.push_back(StrCat("scatter=", PrintName(scatter()->name(), options))); + } else if (opcode() == HloOpcode::kConditional) { + extra.push_back(StrCat("true_computation=", + PrintName(true_computation()->name(), options))); + extra.push_back(StrCat("false_computation=", + PrintName(false_computation()->name(), options))); } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce) { - extra.push_back(StrCat("to_apply=%", to_apply()->name())); + extra.push_back( + StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { extra.push_back(StrCat( "calls=", Join(called_computations(), ", ", - [](string* out, const HloComputation* computation) { - StrAppend(out, "%", computation->name()); + [&](string* out, const HloComputation* computation) { + StrAppend(out, + PrintName(computation->name(), options)); }))); } @@ -1963,8 +2082,9 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (!control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", Join(control_predecessors_, ", ", - [](string* out, HloInstruction* pre) { - StrAppend(out, pre->name()); + [&](string* out, HloInstruction* pre) { + StrAppend(out, + PrintName(pre->name(), options)); }), "}")); } @@ -1975,14 +2095,26 @@ std::vector HloInstruction::ExtraAttributesToString() const { extra.push_back( StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); } + if (opcode() == HloOpcode::kRng) { + extra.push_back( + StrCat("distribution=", RandomDistributionToString(distribution_))); + } + if (opcode() == HloOpcode::kReducePrecision) { + extra.push_back(StrCat("exponent_bits=", exponent_bits_)); + extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); + } + if (opcode() == HloOpcode::kCustomCall) { + extra.push_back( + StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + } return extra; } string HloInstruction::ToShortString() const { - return StrCat(name(), " = ", HloOpcodeString(opcode()), "(", + return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", Join(operands_, ", ", [](string* out, HloInstruction* operand) { - StrAppend(out, operand->name()); + StrAppend(out, "%", operand->name()); }), ")"); } @@ -2004,7 +2136,6 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_literal() = literal_->ToProto(); } proto.set_parameter_number(parameter_number_); - proto.set_parameter_name(parameter_name_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); *proto.mutable_fused_instructions_computation() = @@ -2026,6 +2157,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = *convolution_dimension_numbers_; } + if (dot_dimension_numbers_ != nullptr) { + *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2076,8 +2210,10 @@ string HloInstruction::ToCategory() const { bool saw_rank_1 = false; bool saw_higher_rank = false; for (const auto* operand : operands()) { - saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; - saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; + if (!ShapeUtil::IsTuple(operand->shape())) { + saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; + saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; + } } if (saw_rank_1 && saw_higher_rank) { return "rank-1-broadcast binary fusion"; @@ -2130,25 +2266,13 @@ bool HloInstruction::IsFusable() const { if (tracing()) { return false; } - // Some kinds of instructions don't make sense to fuse. switch (opcode_) { - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: case HloOpcode::kParameter: - case HloOpcode::kTrace: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: return false; - // Only fuse Rng if it is used once, otherwise the random numbers generated - // will be different in each fusion. If it is the root (user count = 0) - // then it is the equivalent of having one user. - case HloOpcode::kRng: - return users_.size() <= 1; + // Side effecting instrutions cannot be fused. default: - return true; + return !HasSideEffect(); } } @@ -2199,7 +2323,7 @@ HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), shape_(shape), - name_("%" + HloOpcodeString(opcode)) { + name_(HloOpcodeString(opcode)) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); } @@ -2259,6 +2383,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleConcatenate(this); case HloOpcode::kConvert: return visitor->HandleConvert(this); + case HloOpcode::kBitcastConvert: + return visitor->HandleBitcastConvert(this); case HloOpcode::kCopy: return visitor->HandleCopy(this); case HloOpcode::kMultiply: @@ -2345,6 +2471,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleFusion(this); case HloOpcode::kCall: return visitor->HandleCall(this); + case HloOpcode::kConditional: + return visitor->HandleConditional(this); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this); case HloOpcode::kRecv: @@ -2357,7 +2485,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); // These opcodes are not handled here. - case HloOpcode::kConditional: case HloOpcode::kTrace: break; } @@ -2423,7 +2550,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, visitor->GetVisitState(current_id); if (visit_state == Visitor::kVisited) { dfs_stack.pop_back(); - VLOG(3) << "Not visiting HLO " << current_node->name() + VLOG(3) << "Not visiting HLO %" << current_node->name() << " as it was already visited."; continue; } @@ -2432,7 +2559,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, dfs_stack.pop_back(); TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); - VLOG(2) << "Visiting HLO " << current_node->name(); + VLOG(2) << "Visiting HLO %" << current_node->name(); TF_RETURN_IF_ERROR(current_node->Visit(visitor)); visitor->SetVisitState(current_id, Visitor::kVisited); TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); @@ -2477,7 +2604,7 @@ template Status HloInstruction::Accept(DfsHloVisitorBase* visitor, bool call_finish_visit, bool ignore_control_predecessors) { - VLOG(3) << "HloInstruction::Accept(" << name() << ")"; + VLOG(3) << "HloInstruction::Accept(%" << name() << ")"; TF_RETURN_IF_ERROR( PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { @@ -2493,7 +2620,7 @@ template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool); Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { - VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; + VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")"; InternalCompareFunction func = [&operand_order]( std::pair a, std::pair b) { @@ -2556,7 +2683,7 @@ Status HloInstruction::Accept( Status HloInstruction::AcceptOrdered( DfsHloVisitor* visitor, const std::vector& order) { - VLOG(2) << "HloInstruction::AcceptOrdered(" << name() << ")"; + VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")"; TF_RET_CHECK(OrderIsTopologicalSort(order)); // Compute the predecessors of this instruction. @@ -2575,7 +2702,7 @@ Status HloInstruction::AcceptOrdered( // The visitor can mark instructions as visited to skip particular // instructions. if (visitor->DidVisit(*const_instruction)) { - VLOG(3) << "Not visiting HLO " << const_instruction->name() + VLOG(3) << "Not visiting HLO %" << const_instruction->name() << " as it was already visited."; continue; } @@ -2584,7 +2711,7 @@ Status HloInstruction::AcceptOrdered( const_cast(const_instruction); TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(2) << "Visiting HLO " << instruction->name(); + VLOG(2) << "Visiting HLO %" << instruction->name(); TF_RETURN_IF_ERROR(instruction->Visit(visitor)); visitor->SetVisited(*instruction); TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); @@ -2630,6 +2757,7 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: @@ -2947,6 +3075,28 @@ string OpMetadataToString(const OpMetadata& metadata) { return Join(result, " "); } +string RandomDistributionToString(const RandomDistribution& distribution) { + return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); +} + +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -2967,25 +3117,25 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". - std::vector lhs_dims(2 + dnums.spatial_dimensions().size()); + std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); lhs_dims[dnums.input_batch_dimension()] = 'b'; lhs_dims[dnums.input_feature_dimension()] = 'f'; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i); + for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) { + lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i); } std::vector rhs_dims(2 + dnums.kernel_spatial_dimensions().size()); rhs_dims[dnums.kernel_input_feature_dimension()] = "i"; rhs_dims[dnums.kernel_output_feature_dimension()] = "o"; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { + for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) { rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i); } - std::vector output_dims(2 + dnums.spatial_dimensions().size()); + std::vector output_dims(2 + dnums.output_spatial_dimensions().size()); output_dims[dnums.output_batch_dimension()] = 'b'; output_dims[dnums.output_feature_dimension()] = 'f'; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - output_dims[dnums.spatial_dimensions(i)] = StrCat(i); + for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) { + output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } result += "dim_labels="; @@ -2997,6 +3147,30 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { return result; } +string HloInstruction::DotDimensionNumbersToString() const { + string result; + if (dot_dimension_numbers_ == nullptr) { + return result; + } + const DotDimensionNumbers& dnums = *dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result += "lhs_batch_dims="; + StrAppend(&result, Join(dnums.lhs_batch_dimensions(), ",")); + } + result += "lhs_contracting_dims="; + StrAppend(&result, Join(dnums.lhs_contracting_dimensions(), ",")); + + result += ","; + if (!dnums.rhs_batch_dimensions().empty()) { + result += "rhs_batch_dims="; + StrAppend(&result, Join(dnums.rhs_batch_dimensions(), ",")); + } + result += "rhs_contracting_dims="; + StrAppend(&result, Join(dnums.rhs_contracting_dimensions(), ",")); + + return result; +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index f5f40ad9475568496ad8da5ad528289f9867c29f..2083c1b81d4a69ea9cdb3c15a8f78d1d3b404309 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -55,6 +56,90 @@ namespace xla { class HloComputation; class HloModule; +// A bunch of switches that control how the hlo text should be printed. +class HloPrintOptions { + public: + // Constructs the default print options: don't print large constants, don't + // compact operands, no indentation. + HloPrintOptions() + : print_large_constants_(false), + print_metadata_(true), + compact_operands_(false), + print_operand_shape_(true), + print_program_shape_(true), + print_percent_(true), + indent_amount_(0) {} + + static HloPrintOptions ShortParsable() { + return HloPrintOptions() + .set_print_large_constants(true) + .set_print_metadata(false) + .set_print_operand_shape(false) + .set_print_program_shape(false) + .set_print_percent(false); + } + + // If true, large constants will be printed out. + HloPrintOptions& set_print_large_constants(bool value) { + print_large_constants_ = value; + return *this; + } + + // If true, metatdata will be printed. + HloPrintOptions& set_print_metadata(bool value) { + print_metadata_ = value; + return *this; + } + + // If true, operands' shapes will be printed. + HloPrintOptions& set_print_operand_shape(bool value) { + print_operand_shape_ = value; + return *this; + } + + // If true, program shape of hlo computations will be printed. + HloPrintOptions& set_print_program_shape(bool value) { + print_program_shape_ = value; + return *this; + } + + // If true, names will be printed with prefix '%'. + HloPrintOptions& set_print_percent(bool value) { + print_percent_ = value; + return *this; + } + + // If true, only a part of operands will be printed out, and their names will + // be omitted (note that in this case the text will not be parsable). + HloPrintOptions& set_compact_operands(bool value) { + compact_operands_ = value; + return *this; + } + + // The indent of the hlo text block. + HloPrintOptions& set_indent_amount(int value) { + indent_amount_ = value; + return *this; + } + + bool print_large_constants() const { return print_large_constants_; } + bool print_metadata() const { return print_metadata_; } + bool compact_operands() const { return compact_operands_; } + bool print_operand_shape() const { return print_operand_shape_; } + bool print_program_shape() const { return print_program_shape_; } + bool print_percent() const { return print_percent_; } + int indent_amount() const { return indent_amount_; } + + private: + bool print_large_constants_; + bool print_metadata_; + bool compact_operands_; + bool print_operand_shape_; + bool print_program_shape_; + bool print_percent_; + int indent_amount_; +}; + // HLO instructions are the IR used by the high-level compiler. class HloInstruction { public: @@ -83,12 +168,16 @@ class HloInstruction { // must contain all operands of the newly constructed instruction. // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed instruction - // calls. If the instruction is a fusion instruction, then the fusion - // computation is added to this map and the module. + // calls. + // add_fused_computation: A function to call to add a fused + // computation. Used (clearly) when the instruction is a fusion + // instruction. static StatusOr> CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, - tensorflow::gtl::FlatMap* computation_map); + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -155,6 +244,18 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + static std::unique_ptr CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers); + + // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 + // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS + // and the RHS must be of rank 2. + static std::unique_ptr CreateCanonicalDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); + // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to // reduce it to. @@ -164,13 +265,19 @@ class HloInstruction { // Creates a cross replica sum op. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, HloInstruction* operand); + const Shape& shape, + tensorflow::gtl::ArraySlice operands); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr CreateConvert(const Shape& shape, HloInstruction* operand); + // Creates a bitcast conversion instruction, where operand is the data to + // convert and shape is the target shape for the conversion. + static std::unique_ptr CreateBitcastConvert( + const Shape& shape, HloInstruction* operand); + // Creates an infeed instruction, which reads data of the given shape from the // Infeed interface of the device. static std::unique_ptr CreateInfeed(const Shape& shape, @@ -305,6 +412,11 @@ class HloInstruction { HloComputation* body, HloInstruction* init); + static std::unique_ptr CreateConditional( + const Shape& shape, HloInstruction* pred, + HloInstruction* true_computation_arg, HloComputation* true_computation, + HloInstruction* false_computation_arg, HloComputation* false_computation); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -406,7 +518,7 @@ class HloInstruction { Status RemoveControlDependencyTo(HloInstruction* instruction); // Returns the set of control predecessors (successors) of this - // instruction. Control predecessors (sucessors) must execute before (after) + // instruction. Control predecessors (successors) must execute before (after) // the current instruction. const std::vector& control_predecessors() const { return control_predecessors_; @@ -525,16 +637,6 @@ class HloInstruction { return parameter_number_; } - const string& parameter_name() const { - CHECK_EQ(HloOpcode::kParameter, opcode_); - return parameter_name_; - } - - void set_parameter_name(const string& str) { - CHECK_EQ(HloOpcode::kParameter, opcode_); - parameter_name_ = str; - } - // Returns the dimension sizes or numbers associated with this instruction. // // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, @@ -608,23 +710,34 @@ class HloInstruction { void set_select(HloComputation* select); void set_scatter(HloComputation* scatter); + // Gets/sets the true and false HloComputation for Conditional. The setters + // should only be called by HloModule or HloComputation methods. + // + // Precondition: The instruction is a Conditional instruction. + HloComputation* true_computation() const; + HloComputation* false_computation() const; + void set_true_computation(HloComputation* true_computation); + void set_false_computation(HloComputation* false_computation); + // Returns a string for the signature of this instruction if considered as a // function, e.g. the signature of an F32 add is (F32, F32) -> F32. string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false, bool include_metadata = true, - bool include_large_constants = false) const; + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Components of the ToString() representation: // Returns a string representation of the operand list. - string OperandsToString(bool compact, bool include_large_constants) const; + string OperandsToString(const HloPrintOptions& options) const; // Returns string representation of op-specific attributes. - std::vector ExtraAttributesToString() const; - - string ToStringNoMetadata() const { return ToString(false, false); } + std::vector ExtraAttributesToString( + const HloPrintOptions& options) const; // As ToString, but returns a shorter string. string ToShortString() const; @@ -652,13 +765,15 @@ class HloInstruction { // Returns feature_index field associated with the instruction. The index // represents the index of the feature dimension. // - // Precondition: opcode() == HloOpcode::kBatchNormTraining + // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, + // or kBatchNormGrad. int64 feature_index() const { return feature_index_; } // Returns a epsilon value associated with the instruction. The is a small // number added to the variance to avoid divide-by-zero error. // - // Precondition: opcode() == HloOpcode::kBatchNormTraining + // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, + // or kBatchNormGrad. float epsilon() const { return epsilon_; } // Returns the infeed configuration string. The infeed configuration includes @@ -891,6 +1006,15 @@ class HloInstruction { // Returns the dump string of the convolution dimension numbers. string ConvolutionDimensionNumbersToString() const; + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + CHECK(dot_dimension_numbers_ != nullptr); + return *dot_dimension_numbers_; + } + + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -982,10 +1106,9 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns a string identifier for this instruction. If no string identifier - // has been explicitly set, then the identifier is the serialized pointer to - // this instruction. + // Gets/sets the string identifier for this instruction. const string& name() const { return name_; } + void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1149,6 +1272,9 @@ class HloInstruction { // Describes the dimension numbers used for a convolution. std::unique_ptr convolution_dimension_numbers_; + // Describes the dimension numbers used for a dot. + std::unique_ptr dot_dimension_numbers_; + // Describes the [begin, end) index range for a slice. std::vector slice_starts_; std::vector slice_limits_; @@ -1174,7 +1300,6 @@ class HloInstruction { // For parameter instructions this field holds the parameter number. int64 parameter_number_ = 0; - string parameter_name_; // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; @@ -1192,6 +1317,10 @@ class HloInstruction { // kSelectAndScatter computations. kSelectComputationIndex = 0, kScatterComputationIndex = 1, + + // kConditional computations. + kTrueComputationIndex = 0, + kFalseComputationIndex = 1, }; // Outfeed configuration information, only present for kOutfeed. @@ -1239,9 +1368,12 @@ string ToString(HloInstruction::FusionKind kind); StatusOr StringToFusionKind( const string& kind_name); -// Custom stringification functions for protos that live inside HloInstruction. +// Custom (de)stringification functions for protos that live inside +// HloInstruction. string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); +string RandomDistributionToString(const RandomDistribution& distribution); +StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index c383dea40555f4768eba6e59c98ac0c932284847..043c751a5e7193d80c3afd6fe2ccdb3434149feb 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1068,8 +1068,11 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1088,48 +1091,6 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { root2->operand(1)->operand(0)->shape())); } -TEST_F(HloInstructionTest, IsRandomFusable) { - auto shape = ShapeUtil::MakeShape(F32, {2, 2}); - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->opcode()); - } - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - builder.AddInstruction(HloInstruction::CreateUnary( - shape, HloOpcode::kNegate, rng)); - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode()); - } -} - - TEST_F(HloInstructionTest, CloneSuffixNames) { // Test that the suffix string added to cloned instructions is not // duplicated. Rather a numeric incrementing value should be appended. That @@ -1138,39 +1099,38 @@ TEST_F(HloInstructionTest, CloneSuffixNames) { // Test cloning the same instruction multiple times. auto foo = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo"); - EXPECT_EQ(foo->Clone()->name(), "%foo.clone"); - EXPECT_EQ(foo->Clone()->Clone()->name(), "%foo.clone2"); - EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "%foo.clone3"); + EXPECT_EQ(foo->Clone()->name(), "foo.clone"); + EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2"); + EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3"); // Test custom suffixes. - EXPECT_EQ(foo->Clone("bar")->name(), "%foo.bar"); - EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "%foo.bar2"); - EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), - "%foo.bar2.clone"); + EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone"); // Test instruction name with a dot. auto foo_baz = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.baz"); - EXPECT_EQ(foo_baz->Clone()->name(), "%foo.baz.clone"); + EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone"); // Test incrementing a large number after the suffix. auto foo_clone234 = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clone234"); - EXPECT_EQ(foo_clone234->Clone()->name(), "%foo.clone235"); + EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235"); // Test a non-numeric string after the cloning suffix. auto foo_clonexyz = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz"); - EXPECT_EQ(foo_clonexyz->Clone()->name(), "%foo.clonexyz.clone"); + EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone"); // Test a name with multiple appearances of the suffix. auto foo_clone_clone3 = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3"); - EXPECT_EQ(foo_clone_clone3->Clone()->name(), "%foo.clone.clone4"); + EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4"); } TEST_F(HloInstructionTest, Stringification) { - // Tests stringification of a simple op, fusion, and while. + // Tests stringification of a simple op, fusion, while, and conditional. const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); @@ -1183,12 +1143,17 @@ TEST_F(HloInstructionTest, Stringification) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(dot->ToString(false, false), + EXPECT_EQ(dot->ToString(options), "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " - "%transpose)"); + "%transpose), lhs_contracting_dims=1,rhs_contracting_dims=0"); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1196,15 +1161,25 @@ TEST_F(HloInstructionTest, Stringification) { {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); EXPECT_EQ( - fusion->ToString(false, false), + fusion->ToString(options), "%fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); - EXPECT_EQ(loop->ToString(false, false), + EXPECT_EQ(loop->ToString(options), "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " "condition=%TransposeDot, body=%TransposeDot"); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + EXPECT_EQ(conditional->ToString(options), + "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, " + "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), " + "true_computation=%TransposeDot, false_computation=%TransposeDot"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 268fa0f632d838c1122f655ea6a548335727390a..992f55788b4900949f4994ba5b7be015bcd0d3de 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -87,6 +87,7 @@ HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); HLO_MATCHER(Concatenate); +HLO_MATCHER(Conditional); HLO_MATCHER(Constant); HLO_MATCHER(Convert); HLO_MATCHER(Convolution); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index d9c223fbbad5a3c20cba6d902ef5bc79e35304d1..6103cab3e7e73079ef9e65b4ada181aa088c4541 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -35,14 +35,15 @@ namespace xla { HloModule::HloModule(const string& name, const VersionedComputationHandle& entry_computation_handle, const HloModuleConfig& config) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), config_(config), has_entry_computation_handle_(true), entry_computation_handle_(entry_computation_handle) {} -HloModule::HloModule(const string& name) : name_(name) {} +HloModule::HloModule(const string& name) + : name_(NameUniquer::GetSanitizedName(name)) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) - : name_(name), config_(config) {} + : name_(NameUniquer::GetSanitizedName(name)), config_(config) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -170,17 +171,14 @@ void HloModule::ReplaceComputations( computations_ = std::move(new_computations); } -string HloModule::ToString(bool include_large_constants) const { +string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << ":\n\n"; + s << "HloModule " << name() << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString( - /*nested_level=*/0, - /*include_large_constants=*/include_large_constants) - << "\n\n"; + s << computation->ToString(options) << "\n\n"; } return s.str(); } @@ -232,8 +230,8 @@ StatusOr ProgramShapeFromProto(const HloModuleProto& module) { << "Entry computation has more than one parameter instruction " "with parameter number " << instruction.parameter_number(); - parameters[instruction.parameter_number()] = { - instruction.parameter_name(), &instruction.shape()}; + parameters[instruction.parameter_number()] = {instruction.name(), + &instruction.shape()}; } } TF_RET_CHECK(root != nullptr) @@ -290,9 +288,16 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatMap computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, &computation_map)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, computation_map, + /*add_fused_computation=*/ + [&module](std::unique_ptr fused_computation) { + module->AddComputationInternal(std::move(fused_computation), + /*is_entry=*/false, + /*uniquify_names=*/false); + })); CHECK_NE(computation.get(), nullptr); TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); string computation_name = computation->name(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 5141e7bc8d4cf0ef4cd83310772e0c5d66b5da12..d3bb46bffca15549ef22e2908f129efd8586fa67 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -98,6 +98,10 @@ class HloModule { return config_.mutable_entry_computation_layout(); } + ComputationLayout entry_computation_layout() const { + return config_.entry_computation_layout(); + } + const VersionedComputationHandle& entry_computation_handle() const { return entry_computation_handle_; } @@ -143,7 +147,12 @@ class HloModule { const HloModuleConfig& config() const { return config_; } - string ToString(bool include_large_constants = false) const; + // Return a string representation of the module. + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index bf6440d66cac0d3a929c377202b212aba262f887..0f5d3dccb74e6e3c88e51685392171f940c03596 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -135,14 +135,15 @@ TEST_F(HloModuleTest, LargeConstantToString) { module->AddEntryComputation(builder.Build()); EXPECT_EQ( - "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n " "ROOT %constant = f32[16]{0} constant({...})\n}\n\n", - module->ToString(/*include_large_constants=*/false)); + module->ToString(HloPrintOptions().set_print_large_constants(false))); + EXPECT_EQ( - "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n " "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, " "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n", - module->ToString(/*include_large_constants=*/true)); + module->ToString(HloPrintOptions().set_print_large_constants(true))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 7b07027441670ed3f72ef802770858fb8a7476fe..f3f79357582ac7661a532e94031acdbca0b86784 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -52,6 +52,7 @@ namespace xla { V(kBatchNormInference, "batch-norm-inference") \ V(kBatchNormTraining, "batch-norm-training") \ V(kBitcast, "bitcast") \ + V(kBitcastConvert, "bitcast-convert") \ V(kBroadcast, "broadcast") \ V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index 071c5a6629addad1a25116739a4d34e7ce55070a..e944ad15139af0d2f98e8e68d3d48303f47ecf1c 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -50,7 +50,7 @@ string HloProfilePrinter::ToString(const int64* counters, /*short_name=*/instruction->short_name, instruction->category, counters[instruction->profile_index], instruction->flop_count, instruction->transcendental_count, instruction->bytes_accessed, - instruction->seconds); + instruction->optimal_seconds); } result += builder.ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h index 45921c66f68e811ef9d0ca3acc37465f5a160c94..2f056490ae027872570f7a0821ee63114f49fab8 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.h +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h @@ -41,7 +41,7 @@ class HloProfilePrinter { float flop_count; float transcendental_count; float bytes_accessed; - float seconds; + float optimal_seconds; // The index into the profile counters array for the HloInstruction // corresponding to this HloInstructionInfo. @@ -65,9 +65,11 @@ class HloProfilePrinter { HloProfilePrinter( HloComputationInfo* computation_infos, int64 computation_infos_size, + int64 profile_counters_size, std::function deleter = nullptr) : computation_infos_(computation_infos), computation_infos_size_(computation_infos_size), + profile_counters_size_(profile_counters_size), deleter_(std::move(deleter)) {} HloProfilePrinter(HloProfilePrinter&& other) { @@ -79,10 +81,13 @@ class HloProfilePrinter { HloProfilePrinter(const HloProfilePrinter&) = delete; HloProfilePrinter& operator=(const HloProfilePrinter&) = delete; - // Convert the profile counter sequence `counters` to a human readable string + // Converts the profile counter sequence `counters` to a human readable string // representation. string ToString(const int64* counters, double clock_rate_ghz) const; + // Returns the size of the profile buffer expected by this printer. + int64 profile_counters_size() const { return profile_counters_size_; } + ~HloProfilePrinter(); private: @@ -90,6 +95,7 @@ class HloProfilePrinter { // is manifested as the deleter_ function. HloComputationInfo* computation_infos_ = nullptr; int64 computation_infos_size_ = 0; + int64 profile_counters_size_ = 0; std::function deleter_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index d7bdac9c86579f19afbba133772c2c50894853d1..553ec11f6f9a2997ab7113f9b8241e04c7fe20d5 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -30,11 +30,17 @@ namespace xla { class HloInstruction; -// A class for computing and representing reachability between HloInstructions. +// A class for representing reachability between HloInstructions. +// +// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix +// and it is up to the user of the class to set the adjacency matrix such that +// it represents reachability, i.e. such that it is transitive. That the graph +// be transitive is thus not an invariant of this class, but it is required for +// the name of the class and its methods to make sense. class HloReachabilityMap { public: - // Sets up an empty reachable matrix for the full set of instructions - // specified in 'instructions'. + // Sets up a graph with no edges and where the nodes correspond to the given + // instructions. explicit HloReachabilityMap(const std::list& instructions); // Set the reachability set of 'instruction' to the union of the reachability @@ -42,17 +48,33 @@ class HloReachabilityMap { // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from // itself. Returns whether the reachability set of 'instruction' changed. + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // vector in the internal graph of this HloReachabilityMap for the given + // instruction and does not transitively update any other part of the + // adjacency matrix. bool SetReachabilityToUnion( tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // matrix in the internal graph of this HloReachabilityMap to have an edge + // from a to b and does not transitively update any other part of the + // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); // Returns true if "b" is reachable from "a" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; private: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 828be8490c994e1992a99e8a9aa960a279486666..c6b4dc0368d92fd477decdfb38045f74f8696803 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -62,18 +62,11 @@ bool IsRematerializable(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: - case HloOpcode::kOutfeed: - case HloOpcode::kInfeed: case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kTrace: case HloOpcode::kWhile: return false; default: - return true; + return !instruction->HasSideEffect(); } } @@ -573,7 +566,9 @@ Status MemoryUsageTracker::BeginInstruction(Item* item) { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -610,8 +605,9 @@ Status MemoryUsageTracker::EndInstruction() { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); - + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -1028,7 +1024,9 @@ StatusOr HloRematerialization::RematerializeComputation( HloInstruction* best = best_item->instruction; VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " - << memory_tracker.MemoryReducedIfRematerialized(best_item) << ")"; + << HumanReadableNumBytes( + memory_tracker.MemoryReducedIfRematerialized(best_item)) + << ")"; changed = true; remat_count++; @@ -1108,8 +1106,8 @@ StatusOr HloRematerialization::RematerializeComputation( net_instructions_added++; } - VLOG(3) << "memory_usage after rematerialization = " - << memory_tracker.memory_usage(); + VLOG(1) << "memory_usage after rematerialization = " + << HumanReadableNumBytes(memory_tracker.memory_usage()); } const CallSite* callsite = call_graph_node.GetCallSite(instruction); @@ -1215,11 +1213,12 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, - CreateMemoryMinimizingSequence( - *module, [this](const LogicalBuffer& buffer) { - return size_function_(buffer.shape()); - })); + TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + *module, + [this](const LogicalBuffer& buffer) { + return size_function_(buffer.shape()); + }, + scheduler_algorithm_)); // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); @@ -1320,9 +1319,10 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, + SchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { - HloRematerialization remat(size_function); + HloRematerialization remat(scheduler_algorithm, size_function); return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 11f79a6d4158c6251c2faf63e9cac4e742440863..52553439033a3bcfa4b472f13f9cd4b1ecf5ed96 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -20,6 +20,7 @@ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { @@ -65,12 +66,15 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence, + HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm, + SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes = nullptr); protected: - HloRematerialization(const ShapeSizeFunction& size_function) - : size_function_(size_function) {} + HloRematerialization(SchedulerAlgorithm scheduler_algorithm, + const ShapeSizeFunction& size_function) + : scheduler_algorithm_(scheduler_algorithm), + size_function_(size_function) {} ~HloRematerialization() {} // Runs rematerialization on the given module. Returns whether the module was @@ -103,6 +107,9 @@ class HloRematerialization { StatusOr CalledComputationsMemoryUsage( const HloInstruction* instruction) const; + // Selects an algorithm to use for HLO scheduling. + SchedulerAlgorithm scheduler_algorithm_; + // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index d88aa4bb567c6c5f6eab54f12239bf7040339c39..216825959a560bd5baa4b49d1a3cace277e16098 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -158,11 +158,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/14 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -191,11 +191,11 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/20 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -232,11 +232,11 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/17 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -268,11 +268,11 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(body_computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/15 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Both computations should have a rematerialized instruction added. @@ -310,11 +310,11 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/13 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // All computations should have a rematerialized instruction added. @@ -323,6 +323,76 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { EXPECT_EQ(inner_computation->instruction_count(), 8); } +TEST_F(HloRematerializationTest, RngNotRematerialized) { + // Test that a single rng is not rematerialized: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] rng = rng(param) + // F32[1024] tanh = tanh(rng) + // F32[1024] exp = exp(rng) + // F32[1024] add_0 = add(rng, tanh) // LIVE: add_0 + rng + + // // tanh + exp + // + // F32[1024] add_1 = add(rng, add(exp, add_0)) // LIVE: add_1 + add_0 + + // // rng + tanh + exp + // + // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + + // // rng + tanh + exp + auto module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto rng = builder.AddInstruction(HloInstruction::CreateRng( + vec1024_shape_, RandomDistribution::RNG_BERNOULLI, {param})); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng)); + auto add_0 = builder.AddInstruction( + HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh)); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, exp, add_0)))); + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, tanh, add_1)))); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto count_rngs = [](const HloComputation* computation) { + int64 rng_count = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kRng) { + ++rng_count; + } + } + return rng_count; + }; + // Before rematerialization there should be a single broadcast rng in + // the graph. + ASSERT_EQ(count_rngs(entry_computation), 1); + const int64 original_instruction_count = + entry_computation->instruction_count(); + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSERT_OK_AND_ASSIGN( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), + module.get(), SchedulerAlgorithm::kAuto, &sequence)); + EXPECT_TRUE(changed); + // The rng should not have been rematerialized. + EXPECT_EQ(count_rngs(entry_computation), 1); + // There should have been rematerialization. + EXPECT_GT(entry_computation->instruction_count(), original_instruction_count); +} + TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Test that a single instruction is rematerialized several times. Module: // @@ -406,11 +476,11 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -503,11 +573,11 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 63f2b1296ed06d6477e9a24f8034bb57ceabd5cc..7b3a8cef97b5670b1ab753cee14203a58c1e5c27 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -39,6 +39,14 @@ namespace se = ::perftools::gputools; namespace xla { +/*static*/ StatusOr> +HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options) { + HloModuleConfig config; + config.set_debug_options(debug_options); + return tools::Parse(hlo_string, config); +} + /*static*/ StatusOr> HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, const DebugOptions& debug_options) { @@ -104,26 +112,27 @@ HloRunner::HloRunner(se::Platform* platform) { VLOG(1) << "Created HloRunner for platform: " << platform->Name(); } -HloRunner::~HloRunner() { - // Deallocate all the memory allocated during the tests. - for (auto& allocation : allocations_) { - backend().default_stream_executor()->Deallocate(&allocation); - } -} +HloRunner::~HloRunner() {} -StatusOr HloRunner::Execute( +StatusOr> HloRunner::ExecuteInternal( std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments, - Shape* result_shape) { + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes) { + if (run_hlo_passes) { + TF_ASSIGN_OR_RETURN( + module, backend().compiler()->RunHloPasses( + std::move(module), backend().default_stream_executor())); + } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend().compiler()->Compile(std::move(module), - backend().default_stream_executor())); + backend().compiler()->RunBackend(std::move(module), + backend().default_stream_executor())); se::Stream stream(backend().default_stream_executor()); stream.Init(); ExecutableRunOptions run_options; + run_options.set_device_ordinal(backend().default_device_ordinal()); run_options.set_stream(&stream); run_options.set_allocator(backend().memory_allocator()); run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); @@ -133,71 +142,35 @@ StatusOr HloRunner::Execute( ServiceExecutableRunOptions service_run_options( run_options, backend().StreamBorrower(), backend().inter_op_thread_pool()); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase result, - executable->ExecuteOnStream(&service_run_options, arguments, - /*hlo_execution_profile=*/nullptr)); - TF_RET_CHECK(stream.BlockHostUntilDone()); - - allocations_.push_back(result); - *result_shape = executable->result_shape(); - - if (ShapeUtil::IsTuple(*result_shape)) { - // We must record element buffers of tuples as well to avoid leaks. - DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); + // Copy arguments to device. + std::vector> argument_buffers; + std::vector argument_buffer_ptrs; + for (Literal* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - backend().transfer_manager()->ShallowCopyTupleFromDevice( - backend().default_stream_executor(), result, *result_shape)); - - // A tuple may contain the same buffer in more than one element. Keep track - // of the buffers already added to avoid duplicates in allocations_. - std::set added_opaques; - for (auto element_buffer : element_buffers) { - if (added_opaques.count(element_buffer.opaque()) == 0) { - CHECK(element_buffer.opaque() != nullptr); - added_opaques.insert(element_buffer.opaque()); - allocations_.push_back(element_buffer); - } - } + std::unique_ptr argument_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + argument->shape(), run_options.allocator(), + run_options.device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.parent(), *argument, *argument_buffer)); + argument_buffers.push_back(std::move(argument_buffer)); + argument_buffer_ptrs.push_back(argument_buffers.back().get()); } - return result; -} - -StatusOr HloRunner::TransferToDevice( - const Literal& literal) { - // Allocate memory on the device using the stream executor. - int64 allocation_size = - backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); - se::DeviceMemoryBase allocation = - backend().default_stream_executor()->AllocateArray( - allocation_size); - allocations_.push_back(allocation); - - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, &allocation)); - - return allocation; -} + TF_ASSIGN_OR_RETURN( + std::unique_ptr result, + executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs, + /*hlo_execution_profile=*/nullptr)); -StatusOr> HloRunner::TransferFromDevice( - const Shape& shape, se::DeviceMemoryBase device_base) { - auto literal = MakeUnique(); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), device_base, shape, shape, - literal.get())); - return std::move(literal); -} + // Create a ScopedShapedBuffer of the result to manage deallocation. This will + // deallocate all the device memory when it goes out of scope. + TF_ASSIGN_OR_RETURN( + std::unique_ptr scoped_result, + ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator())); -StatusOr> HloRunner::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { - Shape result_shape; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase device_base, - Execute(std::move(module), arguments, &result_shape)); - return TransferFromDevice(result_shape, device_base); + return backend().transfer_manager()->TransferLiteralFromDevice( + stream.parent(), *scoped_result); } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index a5732848c6b4191faf8d7b07c749132ca8b14413..d4b221fb52dff64dda264a931df6fd19b86e5260 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -35,7 +35,8 @@ namespace xla { // A base class for running an HloModule. This executes the given HloModule on a // certain backend directly without using the client interface. HloModule can be -// explicitly built, or loaded from a serialization file (e.g., hlo proto file). +// explicitly built, or loaded from a serialization file (e.g., hlo proto +// file), or parsed from a hlo textual IR string. class HloRunner { public: HloRunner(); @@ -44,6 +45,12 @@ class HloRunner { ~HloRunner(); + // Converts an HloModule from the given hlo textual IR string (in + // HloModule::ToString format). + static StatusOr> CreateModuleFromString( + const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options); + // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. Will try to parse the filename as binary proto, then try as // text proto if that fails. @@ -65,32 +72,14 @@ class HloRunner { // Executes the given module with given literals as input and returns the // result as a Literal. The LiteralPtr type accepts Literal* or // std::unique_ptr. + // + // If run_hlo_passes is false, the module will be executed without Hlo + // optimization. template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals); - - // Executes the given module and returns a global data handle. - StatusOr Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape); - - // Transfers the given literal to the device and returns the data handle. - StatusOr TransferToDevice( - const Literal& literal); - - // Transfers the array referred to by the given handle from the device and - // returns as a Literal. - StatusOr> TransferFromDevice( - const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); - - // Executes the given module and return the result as a Literal. - StatusOr> ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments); + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true); // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. @@ -100,9 +89,12 @@ class HloRunner { Backend& backend(); private: - struct EigenThreadPoolWrapper; + StatusOr> ExecuteInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true); - std::vector allocations_; + struct EigenThreadPoolWrapper; std::unique_ptr thread_pool_wrapper_; @@ -112,14 +104,14 @@ class HloRunner { template StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals) { - std::vector arguments; - for (const auto& literal : literals) { - TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, - TransferToDevice(*literal)); - arguments.push_back(argument); + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + for (const auto& argument : arguments) { + argument_pointers.push_back(&*argument); } - return ExecuteAndTransfer(std::move(module), arguments); + return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 8ccbcaeee4a9c9e94b344231953e20ac8f4b2053..2594c29efd717b3bead34d326c28c7efdf093c50 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -31,6 +31,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +using ::tensorflow::strings::HumanReadableNumBytes; + namespace xla { StatusOr MinimumMemoryForSequence( @@ -367,7 +369,17 @@ StatusOr MinimumMemoryForComputation( StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm == SchedulerAlgorithm::kListSchedule) { + return ListScheduler::Run(computation, points_to_analysis, size_function); + } + if (algorithm == SchedulerAlgorithm::kDfsSchedule) { + return RunDFSMemoryScheduler(computation, points_to_analysis, + size_function); + } + // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. @@ -382,7 +394,7 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, @@ -391,13 +403,15 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); if (list_memory <= dfs_memory) { - VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Chose min-memory list sequence: " + << HumanReadableNumBytes(list_memory); return list_sequence; } else { - VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Chose min-memory dfs sequence: " + << HumanReadableNumBytes(dfs_memory); return dfs_sequence; } } @@ -405,27 +419,30 @@ StatusOr> CreateMemoryMinimizingSequence( } // namespace StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { +CreateMemoryMinimizingSequence(const HloModule& module, + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); for (const auto* computation : module.MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN(sequence[computation], - CreateMemoryMinimizingSequence( - *computation, *points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + sequence[computation], + CreateMemoryMinimizingSequence(*computation, *points_to_analysis, + size_function, algorithm)); } return sequence; } StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function); + size_function, algorithm); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index ec92a56b962152b15981f868369683144aa7c76a..1d1eb1e064f75c2220b39e84b010e720a0c37880 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -33,17 +33,28 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); +enum class SchedulerAlgorithm { + kListSchedule, + kDfsSchedule, + + // Selects the available scheduler algorithm that had the minimum memory in + // the resulting sequence (a la MinimumMemoryForSequence). + kAuto, +}; + // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function); + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); // Overload of above that computes the sequence for a single computation. StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 735666345421657f7f3d714826a428784e6072e7..447c2446668253c932b44b51b2db22bfd47f9957 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -160,7 +160,59 @@ bool HloSharding::HasUniqueDevice() const { } } +Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { + if (!ShapeUtil::IsTuple(shape)) { + return tensorflow::errors::InvalidArgument( + StrCat("Sharding is tuple-shaped but validation shape is not.")); + } + // The easiest way to get the number of elements in a nested tuple is just to + // create a shape tree. We could call GetAsShapeTree, but that will try and + // apply our tuple_shardings_ to the shape tree, and that might cause a crash + // at this point as we haven't validated them. + ShapeTree bool_shape_tree(shape, false); + int64 num_leaves = + std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); + if (num_leaves != tuple_elements_.size()) { + return tensorflow::errors::InvalidArgument( + StrCat("Validation tuple shape has ", num_leaves, + " leaf elements, but this sharding contains ", + tuple_elements_.size(), " elements.")); + } + + // Now we've validated the number of tuple elements, it's safe to request a + // shape tree. + ShapeTree shape_tree = GetAsShapeTree(shape); + for (const auto& index_to_sharding : shape_tree.leaves()) { + Status status = index_to_sharding.second.ValidateNonTuple( + ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); + if (!status.ok()) { + tensorflow::errors::AppendToMessage( + &status, StrCat("Note: While validating sharding tuple element ", + index_to_sharding.first.ToString(), " which is ", + index_to_sharding.second.ToString())); + return status; + } + } + return Status::OK(); +} + Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { + Status status = IsTuple() ? ValidateTuple(shape, num_devices) + : ValidateNonTuple(shape, num_devices); + if (!status.ok()) { + tensorflow::errors::AppendToMessage( + &status, StrCat("Note: While validating sharding ", ToString(), + " against shape ", ShapeUtil::HumanString(shape))); + } + return status; +} + +Status HloSharding::ValidateNonTuple(const Shape& shape, + int64 num_devices) const { + if (ShapeUtil::IsTuple(shape)) { + return tensorflow::errors::InvalidArgument( + StrCat("Validation shape is a tuple but sharding is not.")); + } if (replicated_) { return Status::OK(); } @@ -174,13 +226,11 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { - status = - tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "core ", core, " > ", num_devices, " in tile assignment")); + status = tensorflow::errors::InvalidArgument(StrCat( + "core ", core, " > ", num_devices, " in tile assignment")); } else if (seen_cores.count(core) != 0) { - status = - tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "core ", core, " is not unique in tile assignment")); + status = tensorflow::errors::InvalidArgument( + StrCat("core ", core, " is not unique in tile assignment")); } } seen_cores.insert(core); @@ -196,7 +246,8 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { // The tile rank must be the same as the input rank. if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank"); + "Tile rank is different to the input rank. sharding=", ToString(), + ", input_shape=", ShapeUtil::HumanString(shape)); } // The tile shape must not be the same as the input shape without maximal_ @@ -214,9 +265,9 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { auto tile_dim = tile_shape_.dimensions(i); auto shape_dim = shape.dimensions(i); if (tile_dim > shape_dim) { - return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "Tile is larger than input shape (dimension ", i, ", ", tile_dim, - " > ", shape_dim)); + return tensorflow::errors::InvalidArgument( + StrCat("Tile is larger than input shape (dimension ", i, ", ", + tile_dim, " > ", shape_dim)); } } @@ -226,10 +277,10 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { int64 expected_dim = CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i)); if (tile_assignment_.dimensions()[i] != expected_dim) { - return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "Tile assignment tensor has incorrect shape. Dimension ", i, - " expected ", expected_dim, " but got ", - tile_assignment_.dimensions()[i])); + return tensorflow::errors::InvalidArgument( + StrCat("Tile assignment tensor has incorrect shape. Dimension ", i, + " expected ", expected_dim, " but got ", + tile_assignment_.dimensions()[i])); } } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index dbd16b7c9d4c942a62b4c7ca73b488f10cb83f73..7263198385cf0c84b1dac1e15177dcac99adaafb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -80,6 +80,17 @@ class HloSharding { return HloSharding(flattened_list); } + // Creates a new sharding for a tuple type. The requested tuple shape must not + // be nested. For nested tuples, use the ShapeTree overload. + static HloSharding Tuple(const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)); + CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); + return HloSharding(flattened_list); + } + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -222,6 +233,11 @@ class HloSharding { tile_assignment_({0}), tuple_elements_(tuple_shardings) {} + // Internal helper to validate a tuple sharding. + Status ValidateTuple(const Shape& shape, int64 num_devices) const; + // Internal helper to validate a non-tuple (leaf) sharding. + Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; + bool replicated_; bool maximal_; bool tuple_; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 3161dda271d86cc3eaa24e94d30be28887a604bd..0c7487b3ac77ff181d44dd55ebcf2608feaf02ea 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -145,11 +145,13 @@ TEST_F(HloShardingTest, NestedTuple) { ShapeUtil::MakeShape(F32, {4, 6}), }); + HloSharding tiled_sharding = HloSharding::Tile( + ShapeUtil::MakeShape(F32, {4, 3}), Array({{0, 1}})); OpSharding proto; proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto(); - *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto(); + *proto.add_tuple_shardings() = tiled_sharding.ToProto(); HloSharding tuple_sharding = HloSharding::FromProto(proto).ConsumeValueOrDie(); @@ -157,7 +159,15 @@ TEST_F(HloShardingTest, NestedTuple) { tuple_sharding.GetAsShapeTree(nested_tuple_shape); EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate()); EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0)); - EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1)); + EXPECT_EQ(shape_tree.element({2}), tiled_sharding); + + EXPECT_IS_OK(tuple_sharding.Validate(nested_tuple_shape, /*num_devices=*/5)); + // Test should fail because tuple element count does not match. + EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeTupleShape({}), + /*num_devices=*/5)); + // Test should fail because the input type is not a tuple. + EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeShape(F32, {}), + /*num_devices=*/5)); } TEST_F(HloShardingTest, Hash) { diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 101a710d1cad9401134fdfe1d0ec9df241bc01e1..3dc733940fc89952bd5e75a9b28d9cbf356f8000 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -166,7 +166,7 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); } else { layout_string = StrCat( - "{", Join(instruction->shape().layout().minor_to_major(), ","), "}"); + "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); } attrs["layout"].set_s(layout_string); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index c938450891ac170b1a9bea5eea0c7af19f8a180d..d963a8a2f4fac563f7e8d4e9d4dc3d6e761d40de 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -59,21 +59,27 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleConvert(HloInstruction* convert) override { - if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) { - TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape())) - << "Unsupported complex->real kConvert"; - } return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } + Status HandleBitcastConvert(HloInstruction* convert) override { + return CheckShape(convert, ShapeInference::InferBitcastConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); + } + Status HandleCopy(HloInstruction* copy) override { return CheckUnaryShape(copy); } Status HandleDot(HloInstruction* dot) override { - return CheckBinaryShape(dot); + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers())); + return CheckShape(dot, expected); } Status HandleConvolution(HloInstruction* convolution) override { @@ -87,8 +93,12 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleCrossReplicaSum(HloInstruction* crs) override { - return CheckShape(crs, ShapeInference::InferCrossReplicaSumShape( - crs->operand(0)->shape())); + std::vector operand_shapes; + for (const HloInstruction* operand : crs->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape( + crs, ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } Status HandleReducePrecision(HloInstruction* reduce_precision) override { @@ -141,9 +151,6 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleBitcast(HloInstruction* bitcast) override { - // Bitcasts can be any shape, as long as the size matches the operand size. - TF_RET_CHECK(shape_size_fn_(bitcast->shape()) == - shape_size_fn_(bitcast->operand(0)->shape())); return tensorflow::Status::OK(); } @@ -263,6 +270,15 @@ class ShapeVerifier : public DfsHloVisitor { xla_while->while_body()->ComputeProgramShape().result()); } + Status HandleConditional(HloInstruction* conditional) override { + TF_RETURN_IF_ERROR(CheckShape( + conditional, + conditional->true_computation()->ComputeProgramShape().result())); + return CheckShape( + conditional, + conditional->false_computation()->ComputeProgramShape().result()); + } + Status HandlePad(HloInstruction* pad) override { return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), @@ -272,7 +288,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleSend(HloInstruction* send) override { TF_RET_CHECK(send->users().size() == 1); - const HloInstruction* send_done = send->users()[0]; + const HloInstruction* send_done = send->users().front(); TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); return CheckShape( @@ -290,7 +306,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleRecv(HloInstruction* recv) override { TF_RET_CHECK(recv->users().size() == 1); - const HloInstruction* recv_done = recv->users()[0]; + const HloInstruction* recv_done = recv->users().front(); TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); return CheckShape(recv, @@ -418,6 +434,63 @@ string ComputationsToString( }); } +// Verifies various invariants about the structure of the HLO: +// +// (1) each instruction has a non-null parent() set to the HloComputation which +// contains it. +// +// (2) each computation has a non-null parent() set to the HloModule which +// contains it. +// +// (3) the operands of each instruction are in the same computation as the +// instruction. +Status VerifyHloStructure(HloModule* module) { + for (const HloComputation* computation : module->computations()) { + if (computation->parent() == nullptr) { + return FailedPrecondition("Computation %s has a null parent pointer", + computation->name().c_str()); + } + if (computation->parent() != module) { + return FailedPrecondition( + "Computation %s parent() does not point to parent module", + computation->name().c_str()); + } + + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->parent() == nullptr) { + return FailedPrecondition("Instruction %s has a null parent pointer", + instruction->name().c_str()); + } + if (instruction->parent() != computation) { + return FailedPrecondition( + "Instruction %s parent() does not point to parent computation", + instruction->name().c_str()); + } + } + } + + // Check that operands are in the same computation separately from verifying + // parent() correctness so conditions like a null HloInstruction::parent() are + // identified and reported explicitly above rather than reporting a mismatched + // operand. + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + for (int i = 0; i < instruction->operand_count(); ++i) { + const HloInstruction* operand = instruction->operand(i); + if (operand->parent() != instruction->parent()) { + return FailedPrecondition( + "Operand %d (%s) of instruction %s is in a different " + "computation: %s vs %s", + i, operand->name().c_str(), instruction->name().c_str(), + operand->parent()->name().c_str(), + instruction->parent()->name().c_str()); + } + } + } + } + return tensorflow::Status::OK(); +} + } // namespace Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { @@ -538,6 +611,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } StatusOr HloVerifier::Run(HloModule* module) { + TF_RETURN_IF_ERROR(VerifyHloStructure(module)); + tensorflow::gtl::FlatMap instructions; ShapeVerifier shape_verifier(shape_size_fn_); @@ -571,7 +646,7 @@ StatusOr HloVerifier::Run(HloModule* module) { // or ComputationLowerer::Visit() TF_RET_CHECK(instruction->dimensions().size() == ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO has invalid number of dimensions."; + << "Broadcast HLO has invalid number of dimensions."; } else if (instruction->opcode() == HloOpcode::kWhile) { auto* while_cond = instruction->while_condition(); auto* while_body = instruction->while_body(); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a3b55decc5289e7e576d3c5897b333c0b1bc922 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +using HloVerifierTest = HloTestBase; + +TEST_F(HloVerifierTest, NullInstructionParent) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* negate = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(verifier().Run(module.get()).status()); + + negate->set_parent(nullptr); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer")); +} + +TEST_F(HloVerifierTest, NullComputationParent) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(verifier().Run(module.get()).status()); + + computation->set_parent(nullptr); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer")); +} + +TEST_F(HloVerifierTest, DifferentOperandParents) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* negate = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + HloComputation::Builder emb_builder(TestName()); + HloInstruction* emb_param = emb_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + module->AddEmbeddedComputation(emb_builder.Build()); + + TF_ASSERT_OK(verifier().Run(module.get()).status()); + TF_ASSERT_OK(negate->ReplaceOperandWith(0, emb_param)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("is in a different computation")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index de4804996f84ef68ca80cef0178ad786ddaa3a39..ba901b99e4f3c72c84c1ecdf4e19e58ad9ab6506 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -33,7 +33,9 @@ namespace xla { switch (instruction.opcode()) { // Cheap instructions. case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kBroadcast: case HloOpcode::kCeil: case HloOpcode::kClamp: @@ -53,15 +55,14 @@ namespace xla { case HloOpcode::kInfeed: case HloOpcode::kIsFinite: case HloOpcode::kLe: - case HloOpcode::kAnd: - case HloOpcode::kNot: - case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kReal: @@ -88,9 +89,9 @@ namespace xla { // Expensive instructions. case HloOpcode::kAtan2: - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConvolution: @@ -104,19 +105,19 @@ namespace xla { case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kSort: case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kWhile: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 2704a805a91b93c69b751cdb61305ea7780f0ef2..0819ab3b90b2360c6b0b2afaa89f322afe566eb3 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -92,6 +92,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 6d5796a24b5209355debd80b912b7fa62d40837c..dc63a2224d659fa427d4d1a30c5dc0f94d643b36 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -69,13 +69,19 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { return pipeline.Run(hlo_module).status(); } -StatusOr> InterpreterCompiler::Compile( +StatusOr> InterpreterCompiler::RunHloPasses( + std::unique_ptr hlo_module, + se::StreamExecutor* /*stream_exec*/) { + VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); + TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); + return std::move(hlo_module); +} + +StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); - VLOG(1) << "Generate graph " << hlo_module->name(); - - TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); + VLOG(1) << "Run backend " << hlo_module->name(); // Typically you would visit the HLO graph, building up a compiled equivalent // In this case we are using an HloEvaluator at execution time, so we don't diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index cfdc9b6256569b0137784b0d1db846a5f2339a5d..278cf5184227ae25518b1d46c0e16e4cce7bd1a8 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -43,8 +43,12 @@ class InterpreterCompiler : public Compiler { InterpreterCompiler() {} ~InterpreterCompiler() override {} - StatusOr> Compile( - std::unique_ptr hlo_modules, + StatusOr> RunHloPasses( + std::unique_ptr hlo_module, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr> RunBackend( + std::unique_ptr hlo_module, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 96f937caf96232a72b2f3d80d2269d6ade2327dc..b01fcccdb4b338ed2575d1f2c48401adc648a09a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -42,48 +43,23 @@ namespace sep = ::perftools::gputools::interpreter; InterpreterExecutable::InterpreterExecutable( std::unique_ptr hlo_module) - : Executable(std::move(hlo_module)) {} + : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, + /*hlo_profile_index_map=*/nullptr) {} InterpreterExecutable::~InterpreterExecutable() {} -static se::DeviceMemoryBase AllocateSingleOutput( - sep::InterpreterExecutor* executor, const Literal& literal) { - int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); - void* buf = executor->Allocate(size); - const void* src = literal.InternalData(); - memcpy(buf, src, size); - return se::DeviceMemoryBase(buf, size); -} - -static se::DeviceMemoryBase AllocateOutputBuffer( - sep::InterpreterExecutor* executor, const Literal& literal) { - const Shape& shape = literal.shape(); - if (shape.element_type() != xla::TUPLE) { - return AllocateSingleOutput(executor, literal); - } else { - int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*))); - void** buf = reinterpret_cast(executor->Allocate(size)); - void** buf_rc = buf; - for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) { - se::DeviceMemoryBase out = - AllocateSingleOutput(executor, literal.tuple_literals(n)); - *buf++ = out.opaque(); - } - - return se::DeviceMemoryBase(buf_rc, size); - } -} - -StatusOr InterpreterExecutable::ExecuteOnStream( +StatusOr> InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); + se::StreamExecutor* executor = stream->parent(); + const se::Platform* platform = executor->platform(); VLOG(1) << "Execute " << module().name(); if (VLOG_IS_ON(2)) { for (const auto& a : arguments) { - VLOG(2) << "-- argument " << a.opaque(); + VLOG(2) << "-- argument " << *a; } } @@ -95,33 +71,32 @@ StatusOr InterpreterExecutable::ExecuteOnStream( "Mismatch between argument count and graph parameter count."); } - // Create the arguments as an vector of XLA literals + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, + TransferManager::GetForPlatform(platform)); + + // Transform the ShapedBuffer arguments into literals which the evaluator + // consumes. std::vector> arg_literals; - std::vector arg_literals_ptrs; for (int64 p = 0; p < computation->num_parameters(); ++p) { - // Create the input literal for the parameter - HloInstruction* param = computation->parameter_instruction(p); - arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); - arg_literals_ptrs.push_back(arg_literals.back().get()); - - // Copy in the data from the stream_executor buffers - void* buffer = arg_literals.back()->MutableInternalData(); - memcpy(buffer, arguments[p].opaque(), - ShapeUtil::ByteSizeOf(param->shape())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr arg_literal, + transfer_manager->TransferLiteralFromDevice(executor, *arguments[p])); + arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(std::unique_ptr output, - evaluator.Evaluate(*computation, arg_literals_ptrs)); - - // Copy the result into the return buffer - perftools::gputools::StreamExecutor* executor(stream->parent()); - sep::InterpreterExecutor* interpreter_executor( - static_cast(executor->implementation())); - - se::DeviceMemoryBase ret = - AllocateOutputBuffer(interpreter_executor, *(output.get())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_literal, + evaluator.Evaluate>(*computation, arg_literals)); + + // Transform the result literal back into a ShapedBuffer. + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + transfer_manager->AllocateShapedBuffer( + result_literal->shape(), run_options->allocator(), + run_options->device_ordinal())); + TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( + executor, *result_literal, *result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -131,20 +106,13 @@ StatusOr InterpreterExecutable::ExecuteOnStream( execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); } - return ret; -} - -StatusOr> InterpreterExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - return tensorflow::errors::Unimplemented( - "ExecuteOnStream is not yet supported on Interpreter."); + return std::move(result); } -StatusOr InterpreterExecutable::ExecuteAsyncOnStream( +StatusOr> +InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { return tensorflow::errors::Unimplemented( "ExecuteAsyncOnStream is not yet supported on Interpreter."); } @@ -156,10 +124,5 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream( return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } -std::unique_ptr InterpreterExecutable::CreateCostAnalysis() - const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace interpreter } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index c69b0d036d1058a6b24ee609a9923895d3246eec..410110a1adf04c83001c38ed03f5d60dd203dc7e 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -43,26 +43,17 @@ class InterpreterExecutable : public Executable { InterpreterExecutable(std::unique_ptr hlo_module); ~InterpreterExecutable() override; - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; static int64 ShapeSizeBytes(const Shape& shape); - std::unique_ptr CreateCostAnalysis() const override; - private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 0bb3259ef43915067e614e72038387e8300ecc41..68371910d76f42c0b6d4b1adad9d6a83bdb858e6 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -85,7 +85,7 @@ bool InterpreterExecutor::HostCallback(Stream *stream, bool InterpreterExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { AsExecutorStream(dependent)->EnqueueTask( - [other]() { other->BlockHostUntilDone(); }); + [other]() { SE_CHECK_OK(other->BlockHostUntilDone()); }); AsExecutorStream(dependent)->BlockUntilDone(); return true; } @@ -100,9 +100,9 @@ bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { return true; } -bool InterpreterExecutor::BlockHostUntilDone(Stream *stream) { +port::Status InterpreterExecutor::BlockHostUntilDone(Stream *stream) { AsExecutorStream(stream)->BlockUntilDone(); - return true; + return port::Status::OK(); } DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const { diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index c59b2ccb1505b78be0c459ac9311428d65cc7e44..c5d07e906dafb033905c50c604069e80e1ce80cd 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -157,7 +157,7 @@ class InterpreterExecutor : public internal::StreamExecutorInterface { bool StartTimer(Stream *stream, Timer *timer) override; bool StopTimer(Stream *stream, Timer *timer) override; - bool BlockHostUntilDone(Stream *stream) override; + port::Status BlockHostUntilDone(Stream *stream) override; int PlatformDeviceCount() override { return 1; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 7eda7c2284c2457703fcfcd4226172e41dd4ae01..42bca3b783c5f3390e9507d54fb07660d9f98e35 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -477,16 +477,10 @@ Status LayoutAssignment::AddMandatoryConstraints( /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { // Add constraints for kCustomCall instruction operands and instructions. - // For now we only support row major layouts for all inputs and outputs. - auto row_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - - Shape result_shape(row_major_shape(instruction->shape())); + // For now we only support major-first layouts for all inputs and outputs. + Shape result_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + instruction->shape().element_type(), + AsInt64Slice(instruction->shape().dimensions())); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(result_shape, instruction)); for (int64 i = 0; i < instruction->operand_count(); ++i) { @@ -496,7 +490,10 @@ Status LayoutAssignment::AddMandatoryConstraints( continue; } - Shape row_major_operand_shape(row_major_shape(operand_shape)); + Shape row_major_operand_shape = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + operand_shape.element_type(), + AsInt64Slice(operand_shape.dimensions())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( row_major_operand_shape, instruction, i, /*mandatory=*/true)); } @@ -530,9 +527,11 @@ Status CheckCallLayout(HloInstruction* call, Status CheckCustomCallLayout(HloInstruction* custom_call) { for (const HloInstruction* operand : custom_call->operands()) { TF_RET_CHECK( + ShapeUtil::IsOpaque(operand->shape()) || LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); } TF_RET_CHECK( + ShapeUtil::IsOpaque(custom_call->shape()) || LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); return Status::OK(); } @@ -711,8 +710,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( int64 operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - CHECK(ShapeUtil::IsArray(instruction->shape()) && - ShapeUtil::IsArray(operand->shape())); + CHECK(ShapeUtil::IsArray(instruction->shape())); + CHECK(ShapeUtil::IsArray(operand->shape())); if (instruction->IsElementwiseOnOperand(operand_no) && !ShapeUtil::IsScalar(operand->shape()) && @@ -742,7 +741,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Shape& output_shape = instruction->shape(); Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), - AsInt64Slice(output_layout.minor_to_major())); + LayoutUtil::MinorToMajor(output_layout)); Shape operand_shape = operand->shape(); *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); @@ -771,7 +770,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( int64 rank = ShapeUtil::Rank(instruction->shape()); std::vector new_minor_to_major(rank); for (int64 i = 0; i < rank; ++i) { - int64 output_dim = output_layout.minor_to_major(i); + int64 output_dim = LayoutUtil::Minor(output_layout, i); int64 operand_dim = instruction->dimensions(output_dim); new_minor_to_major[i] = operand_dim; } @@ -814,7 +813,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( operand->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), - AsInt64Slice(operand_layout.minor_to_major())); + LayoutUtil::MinorToMajor(operand_layout)); Shape output_shape = user->shape(); *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); @@ -844,7 +843,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( std::vector new_minor_to_major(rank); auto inverse_dimensions = InversePermutation(user->dimensions()); for (int64 i = 0; i < rank; ++i) { - int64 operand_dim = operand_layout.minor_to_major(i); + int64 operand_dim = LayoutUtil::Minor(operand_layout, i); int64 user_dim = inverse_dimensions[operand_dim]; new_minor_to_major[i] = user_dim; } @@ -1303,7 +1302,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - // Copy the root instrucion's result if the it does not match the result + // Copy the root instruction's result if the it does not match the result // layout constraint if (constraints.ResultLayout() != nullptr && !constraints.ResultLayout()->MatchesLayoutInShape( @@ -1328,6 +1327,20 @@ Status LayoutAssignment::RunOnComputation( << ")"; VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + // Clear existing layouts of the instructions. All layouts must be assigned by + // the LayoutAssignment pass, except for Infeed, Outfeed, Parameters and the + // computation result. The latter two are specified in computation_layout, so + // we only need to keep the existing layouts for Infeed and Outfeed. Clearing + // the layouts here avoids hiding potential bugs in the layout assignment pass + // that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed || + instruction->opcode() == HloOpcode::kOutfeed) { + continue; + } + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 53d88eda7a81a8cd0ea245de84011cce0ab3eafe..68c99256a246edcf43a8358f667fc4458b9b4fea 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -103,7 +103,7 @@ namespace { // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'intruction' at 'index', and +// where 'user' is a user of an alias of 'instruction' at 'index', and // 'operand_index' is the operand index at which the alias appears in the // operand list of 'user'. std::vector> GetAllUsesOfInstructionAtIndex( @@ -243,6 +243,31 @@ bool CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kCall) { + // TODO(b/62548313): Remove when buffer assignment is module scoped and + // does not assign buffers to calls. + // Find called computation parameter associated with 'operand'. + const std::vector operand_indices = user->OperandIndices(operand); + if (operand_indices.size() > 1) { + return false; + } + CHECK_EQ(1, operand_indices.size()); + auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); + // Get all uses of 'operand' at 'index' in called computation. + auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, + points_to_analysis); + + // Return true iff: + // *) There exists exactly one use of 'operand' in called computation. + // *) The unique use is by the root instruction of called computation. + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + auto* callee_root = user->to_apply()->root_instruction(); + return param_uses.size() == 1 && param_uses[0].first == callee_root && + callee_root->IsElementwiseOnOperand(param_uses[0].second); + } // Check if 'user' is element-wise. return user->IsElementwise(); } @@ -322,6 +347,31 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand, std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kCall) { + // Get all uses of value defined by 'operand' at 'operand_index'. + const auto& uses = + dataflow.GetValueDefinedAt(operand, operand_index).uses(); + // Return true iff: + // *) There exists two uses of 'operand'. + // *) One use is by 'user' (caller). + // *) One use is by root instruction of called computation (callee root). + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + const bool found_caller_use = + std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + return use.instruction == user; + }) != uses.end(); + auto* callee_root = user->to_apply()->root_instruction(); + const bool found_elementwise_callee_use = + std::find_if( + uses.begin(), uses.end(), [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); + return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; + } // Check if 'user' is element-wise. return user->IsElementwise(); } diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index b5e15906d3c085f773eb46b543515a614e63c59a..2c2a02f6375343d67dfb155bbb03729ff6e490d2 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -277,8 +277,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto b = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -312,8 +315,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -415,5 +421,44 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); } +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. + auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, + *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, + *dataflow_analysis_)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index ba0304fb8ca0de9cffc705f471eb0b740747ec92..34f3419269abbc73cd0ddb13c723a8da38ab19ff 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -27,8 +27,10 @@ StatusOr>> LLVMCompiler::Compile( "Model partitioning not implemented for the CPU/GPU compilers!"); } + TF_ASSIGN_OR_RETURN( + modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0])); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - Compile(std::move(modules[i]), stream_execs[i][0])); + RunBackend(std::move(modules[i]), stream_execs[i][0])); result.push_back(std::move(executable)); } diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index c4f689eabedd4eabe98d907bd3d6b185dfa4bd10..c5393cef4f961c5d04c32d0d4291732b8ec702f1 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -58,10 +58,14 @@ class LLVMCompiler : public Compiler { void RemovePostOptimizationHook() { user_post_optimization_hook_ = nullptr; } // Bring in - // StatusOr> Compile( - // std::unique_ptr module, - // perftools::gputools::StreamExecutor* executor) - using Compiler::Compile; + // StatusOr> RunBackend( + // std::unique_ptr module, + // perftools::gputools::StreamExecutor* stream_exec) + // StatusOr> RunHloPasses( + // std::unique_ptr module, + // perftools::gputools::StreamExecutor* stream_exec) + using Compiler::RunBackend; + using Compiler::RunHloPasses; StatusOr>> Compile( std::vector> modules, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 7224bd689842d89563b374f3db3d4e314be18764..c558f7388cab587b5858d0594cdb2f3c41d75562 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -39,7 +39,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; int64 divisor = 1; - for (int64 dimension : layout_.minor_to_major()) { + for (int64 dimension : LayoutUtil::MinorToMajor(layout_)) { int64 size_of_current_dimension = shape.dimensions(dimension); // Emit IR instructions that compute // (linear_index / divisor) % current_dimension @@ -244,8 +244,8 @@ llvm::Value* IrArray::EmitArrayElementAddress( // // getelementptr base_ptr_, 0, most major index, ..., most minor index std::vector gep_indices(1, ir_builder->getInt64(0)); - for (int64 i = shape_->layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape_->layout().minor_to_major(i); + for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_->layout(), i); gep_indices.push_back(actual_index[dimension]); } return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices, diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 29cc0f81bd2c06538e28d1b593ee6a897fea0f27..23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { void KernelSupportLibrary::For( @@ -62,4 +63,72 @@ void KernelSupportLibrary::If( false_block_generator(); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); } + +void KernelSupportLibrary::EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + KernelSupportLibrary::ArgumentVector arguments, + const std::function& + kernel_body_generator) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + llvm::Function* function = + module->getFunction(llvm_ir::AsStringRef(kernel_name)); + + int64 null_arg_idx = -1; + std::vector sanitized_args; + sanitized_args.reserve(arguments.size()); + for (int64 i = 0, e = arguments.size(); i < e; i++) { + if (arguments[i]) { + sanitized_args.push_back(arguments[i]); + } else { + CHECK_EQ(null_arg_idx, -1); + null_arg_idx = i; + } + } + + if (!function) { + VLOG(2) << "Generating kernel for " << kernel_name; + std::vector arg_types; + std::transform(sanitized_args.begin(), sanitized_args.end(), + std::back_inserter(arg_types), + [](llvm::Value* arg) { return arg->getType(); }); + + auto* function_type = llvm::FunctionType::get( + ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false); + + function = llvm_ir::CreateFunction( + function_type, llvm::GlobalValue::InternalLinkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, kernel_name, module); + + llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder); + + auto* entry_bb = + llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function); + auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(), + /*retVal=*/nullptr, entry_bb); + // Set the insert point to before return_inst. + ir_builder->SetInsertPoint(return_inst); + + std::vector arg_values; + /* + * clang on OSX doesn't like std::transform or range for loop here. + * See https://github.com/tensorflow/tensorflow/issues/15196 + */ + for (llvm::Function::arg_iterator arg = function->arg_begin(), + arg_e = function->arg_end(); + arg != arg_e; ++arg) { + arg_values.push_back(arg); + } + if (null_arg_idx != -1) { + arg_values.insert(arg_values.begin() + null_arg_idx, nullptr); + } + kernel_body_generator(arg_values); + } else { + VLOG(3) << "Re-using kernel for " << kernel_name; + } + + ir_builder->CreateCall(function, llvm_ir::AsArrayRef(sanitized_args)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 9bafb7b57740b7acd0286c113c8a0585c0f93689..827e092a3fa9116c461716b27c309033f7988745 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -118,6 +118,60 @@ class KernelSupportLibrary { const std::function& true_block_generator, const std::function& false_block_generator = []() {}); + using ArgumentVector = tensorflow::gtl::ArraySlice; + + // Generates the following control flow structure: + // + // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { + // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`}); + // } + // + // ... + // call @`kernel_name`(arguments[0], arguments[1] ...) + // ... + // + // If a function called `kernel_name` is already present in the module then + // that function is re-used. In that sense we're using the llvm::Module as a + // cache of outlined kernels, keyed by function name. + // + // If any of the values in `arguments` is nullptr (i.e. a nullptr + // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass + // in a nullptr llvm::Value* in its position to `kernel_body_generator`. + // Currently we only support at most one nullptr value in `arguments`. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + ArgumentVector arguments, + const std::function& kernel_body_generator); + + // Thin wrappers around the more general EmitAndCallOutlinedKernel above. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + const std::function& + kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2]); + }); + } + + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + llvm::Value* arg3, + const std::function& kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2], args[3]); + }); + } + private: llvm::IRBuilder<>* ir_builder_; bool prevent_unrolling_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index cd0c4a371e2b1cd0e1c52b77e47e8b081ab8e836..61c47a0b6eca38db5d78dc622a8cf909f6cf14ee 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -142,6 +142,13 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: + case BF16: + // For BF16 we just need some type that is 16 bits wide so that it will + // take up the right amount of space in memory. LLVM does not have a BF16 + // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so + // we can't map it directly to an LLVM type. We will not map a BF16 + // addition to an addition on this type (int16) - this is just the type + // used for storage. return llvm::Type::getInt16Ty(module->getContext()); case S32: case U32: @@ -200,8 +207,8 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { if (ShapeUtil::IsTuple(shape)) { // A tuple buffer is an array of pointers. result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); - } else { - for (int64 dimension : shape.layout().minor_to_major()) { + } else if (ShapeUtil::IsArray(shape)) { + for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { result_type = llvm::ArrayType::get(result_type, shape.dimensions(dimension)); } @@ -280,6 +287,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); break; + case BF16: + value = llvm::ConstantInt::get( + ir_element_type, + tensorflow::bit_cast(literal.Get(*multi_index))); + break; case F64: value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); @@ -304,7 +316,7 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, // decrements with each recursive call. We want to iterate through the // dimensions in major-to-minor order as we recurse so just index into // minor_to_major to get the dimension number for this level of the recursion. - int64 dimension = shape.layout().minor_to_major(dimension_index); + int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); // Recursively call LiteralToConstant to construct subarrays for the // more-minor dimensions. Gather the subarrays into a vector for bundling into @@ -320,7 +332,7 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, if (elements.empty()) { element_type = ir_element_type; for (int i = 0; i < dimension_index; ++i) { - int64 index = shape.layout().minor_to_major(i); + int64 index = LayoutUtil::Minor(shape.layout(), i); element_type = llvm::ArrayType::get(element_type, shape.dimensions(index)); } @@ -676,5 +688,32 @@ Status DumpIRToDirectory(const string& directory_name, return f->Close(); } +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module) { + llvm::Function* function = + llvm::Function::Create(function_type, linkage, AsStringRef(name), module); + function->setCallingConv(llvm::CallingConv::C); + function->addFnAttr("no-frame-pointer-elim", "false"); + + if (enable_fast_math) { + function->addFnAttr("unsafe-fp-math", "true"); + function->addFnAttr("no-infs-fp-math", "true"); + function->addFnAttr("no-nans-fp-math", "true"); + function->addFnAttr("no-signed-zeros-fp-math", "true"); + } + + // Add the optize attribute to the function if optimizing for size. This + // controls internal behavior of some optimization passes (e.g. loop + // unrolling). + if (optimize_for_size) { + function->addFnAttr(llvm::Attribute::OptimizeForSize); + } + + return function; +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 063ead2b647d8fc5cc4f67004aaded80a2191fe9..6bdc6a01a2b487df3dd80a02e67f5bcf62dead31 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -281,6 +281,12 @@ Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized); +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module); + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 6fa4cd08c9e0ac30b83c0e2b49d98d930c2e15df..a5f7c850c33757fe8d48567ade35544d81224e46 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -99,8 +99,8 @@ IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock( // dimension (of the target shape). ForLoopNest loop_nest(loop_name, ir_builder_); IrArray::Index array_index(shape_.dimensions_size()); - for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape_.layout().minor_to_major(i); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h index 11e84d9cb5defbcb87a8f696d56c139686c960d8..f72f482e3128c61e53cc454e7da8b5795ba6f695 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -40,11 +40,24 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, inline bool CanEmitFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, const BufferAssignment& assignment) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); - return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop && - fusion->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice && - CanUpdateDynamicSliceInPlace(fusion->fused_expression_root(), - assignment); + HloInstruction* fused_root = fusion->fused_expression_root(); + if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice || + fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { + return false; + } + // Walk DynamicUpdateSlice operand(0) to fused parameter and get its + // associated operand. See if it shares an allocation with this operand. + HloInstruction* fusion_operand; + ShapeIndex index; + std::tie(fusion_operand, index) = + fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex(); + if (fusion_operand->opcode() != HloOpcode::kParameter) { + return false; + } + auto* operand = fusion->operand(fusion_operand->parameter_number()); + return assignment.HasAllocationAt(operand, index) && + assignment.HasAllocationAt(fusion, {}) && + assignment.SharesSliceAtIndex(fusion, {}, operand, index); } // Emits IR for running the given dynamic-update-slice op in-place -- that is, diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc index e8c6a83618eaa8430521197f1c166cb7eb11a28e..0f6d8483da88ba4bf3f26961c0cbc8d855faa82c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc @@ -34,6 +34,12 @@ VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, } llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { + CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + return MulInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, + llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { return ir_builder()->CreateFMul(lhs, rhs, name()); } else { @@ -42,6 +48,12 @@ llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { } llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { + CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + return AddInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, + llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { return ir_builder()->CreateFAdd(lhs, rhs, name()); } else { @@ -129,6 +141,122 @@ llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { name()); } +llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, + llvm::Value* rhs) { + CHECK_EQ(lhs->getType(), vector_type()); + CHECK_EQ(rhs->getType(), vector_type()); + CHECK_EQ(vector_size() % 2, 0); + + llvm::SmallVector mask_a, mask_b; + + // Adding the values shuffled using mask_a and mask_b gives us the + // AVX-style horizontal add we want. The masks work as documented + // in https://llvm.org/docs/LangRef.html#shufflevector-instruction + // + // Here are the masks for vector_width() == 8: + // + // index: |0 |1 |2 | 3 |4 |5 | 6 | 7 + // --------+--+--+--+---+--+--+---+--- + // mask_a: |0 |2 |8 |10 |4 |6 |12 |14 + // mask_b: |1 |3 |9 |11 |5 |7 |13 |16 + // + // So, as an example, the value at lane 3 of the result vector is + // the result of adding lane 10 and lane 11 in the combined lhs++rhs + // vector, which are the lanes 2 and 3 in the rhs vector. + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size(); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + + llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_a)); + llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_b)); + + return Add(shuffle_0, shuffle_1); +} + +llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i + vector_size() / 2)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +std::vector VectorSupportLibrary::ComputeHorizontalSums( + std::vector vectors, llvm::Value* init_values) { + // TODO(sanjoy): Move this magic constant to TargetMachineFeatures. + const int kAvxVectorWidth = 8; + if (vector_size() == kAvxVectorWidth && vectors.size() == kAvxVectorWidth) { + return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values); + } + + std::vector result; + std::transform(vectors.begin(), vectors.end(), std::back_inserter(result), + [this](llvm::Value* vector) { return AddReduce(vector); }); + if (init_values) { + for (int64 i = 0, e = result.size(); i < e; i++) { + result[i] = Add(result[i], ir_builder()->CreateExtractElement( + init_values, ir_builder()->getInt32(i))); + } + } + return result; +} + +std::vector +VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( + std::vector vectors, llvm::Value* init_values) { + while (vectors.size() != 2) { + std::vector new_vectors; + for (int i = 0; i < vectors.size(); i += 2) { + new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1])); + } + + vectors = std::move(new_vectors); + } + + llvm::Value* low = + AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0])); + if (init_values) { + low = AddInternal(ExtractLowHalf(init_values), low); + } + llvm::Value* high = + AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1])); + if (init_values) { + high = AddInternal(ExtractHighHalf(init_values), high); + } + + std::vector results; + for (int i = 0; i < 8; i++) { + llvm::Value* scalar_result = ir_builder()->CreateExtractElement( + i < 4 ? low : high, ir_builder()->getInt32(i % 4), name()); + results.push_back(scalar_result); + } + + return results; +} + llvm::Value* VectorSupportLibrary::GetZeroVector() { return llvm::Constant::getNullValue(vector_type()); } @@ -142,7 +270,9 @@ LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); } -llvm::Value* LlvmVariable::Get() { return ir_builder_->CreateLoad(alloca_); } +llvm::Value* LlvmVariable::Get() const { + return ir_builder_->CreateLoad(alloca_); +} void LlvmVariable::Set(llvm::Value* new_value) { ir_builder_->CreateStore(new_value, alloca_); diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h index 3072677ab05aa91c736baaa0dc3023329d810a52..f404687ab6864bd0702d142ff691a394b78278a5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h @@ -111,7 +111,12 @@ class VectorSupportLibrary { return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); } - llvm::Value* AddReduce(llvm::Value* vector); + // Compute the horizontal sum of each vector in `vectors`. The i'th element + // in the result vector is the (scalar) horizontal sum of the i'th vector in + // `vectors`. If `init_values` is not nullptr then the value in the i'th lane + // in `init_values` is added to the i'th horizontal sum. + std::vector ComputeHorizontalSums( + std::vector vectors, llvm::Value* init_values = nullptr); llvm::Value* GetZeroVector(); llvm::Value* GetZeroScalar(); @@ -126,6 +131,33 @@ class VectorSupportLibrary { const std::string& name() const { return name_; } private: + llvm::Value* ExtractLowHalf(llvm::Value*); + llvm::Value* ExtractHighHalf(llvm::Value*); + + llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs); + + llvm::Value* AddReduce(llvm::Value* vector); + + // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The + // resulting IR for an 8-float wide vector is expected to lower to a single + // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in + // other cases. + // + // For a vector width of 8, the result vector is computed as: + // Result[0] = Lhs[0] + Lhs[1] + // Result[1] = Lhs[2] + Lhs[3] + // Result[2] = Rhs[0] + Rhs[1] + // Result[3] = Rhs[2] + Rhs[3] + // Result[4] = Lhs[4] + Lhs[5] + // Result[5] = Lhs[6] + Lhs[7] + // Result[6] = Rhs[4] + Rhs[5] + // Result[7] = Rhs[6] + Rhs[7] + llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs); + + std::vector ComputeAvxOptimizedHorizontalSums( + std::vector vectors, llvm::Value* init_values); + int64 vector_size_; PrimitiveType primitive_type_; llvm::IRBuilder<>* ir_builder_; @@ -142,7 +174,7 @@ class LlvmVariable { public: LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); - llvm::Value* Get(); + llvm::Value* Get() const; void Set(llvm::Value* new_value); private: diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 06f43bd3cb2376d34a3104133c868c4f4e5cc730..4071b948a5f94bcc2e87d8bb3b9533fb3b1d2cb1 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -118,10 +118,8 @@ StatusOr> LocalService::CompileExecutable( TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, execute_backend_->stream_executor(device_ordinal)); - std::vector argument_buffers( - argument_layouts.size()); return BuildExecutable(versioned_handle, std::move(module_config), - argument_buffers, execute_backend_.get(), executor); + execute_backend_.get(), executor); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index a0d08c288dbcc45e83a36ce7b094b04a9dbae532..7d8c05fffa4ab11d7dbf9956d2cb7ebd5bcdd3c4 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,12 +17,44 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + +bool IsAllowed(char character) { + auto c = static_cast(character); + return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; +} + +} // namespace + +NameUniquer::NameUniquer(const string& separator) { + CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + << "separator should comprises allowed characters only"; + separator_ = separator; +} + +/*static*/ string NameUniquer::GetSanitizedName(const string& name) { + string result = name; + CHECK(!result.empty()) << "name should not be empty"; + char c = static_cast(result[0]); + if (!isalpha(c) && c != '_') { + result[0] = '_'; + } + for (int i = 1; i < result.length(); i++) { + if (!IsAllowed(result[i])) { + result[i] = '_'; + } + } + return result; +} + string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); + root = GetSanitizedName(root); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index ed379b52258463b960dea788721c2c4325ef0260..4139c2700b25e8600182a034a8ac6f4f041c12e6 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -28,14 +28,21 @@ namespace xla { // Simple stateful class that helps generate "unique" names. To use it, simply // call GetUniqueName as many times as needed. The names returned by // GetUniqueName are guaranteed to be distinct for this instance of the class. +// Note that the names will be sanitized to match regexp +// "[a-zA-Z_][a-zA-Z0-9_.-]*". class NameUniquer { public: - explicit NameUniquer(const string& separator = "__") - : separator_(separator) {} + // The separator must contain allowed characters only: "[a-zA-Z0-9_.-]". + explicit NameUniquer(const string& separator = "__"); - // Get a unique name in a string, with an optional prefix for convenience. + // Get a sanitized unique name in a string, with an optional prefix for + // convenience. string GetUniqueName(tensorflow::StringPiece prefix = ""); + // Sanitizes and returns the name. Unallowed characters will be replaced with + // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + static string GetSanitizedName(const string& name); + private: // The string to use to separate the prefix of the name from the uniquing // integer value. diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 9f0747a6e2175a968d8f3661ac51512009e86f29..4258cf16876ab46dce6df062ab701b1b1a4a7580 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -60,12 +60,30 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); +} + +TEST_F(NameUniquerTest, Sanitize) { + NameUniquer uniquer("_"); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); + EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); + EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + + // Invalid characters will be replaced with '_'. + EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); + EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. - EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); - EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); - EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); - EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("_10", uniquer.GetUniqueName( + ".10")); // the leading '.' is replaced with '_'. + EXPECT_EQ("_10_1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("_10_2", uniquer.GetUniqueName("_10")); + EXPECT_EQ("foobar_", uniquer.GetUniqueName("foobar_")); + EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } } // namespace diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 3a1818de82d3fd305e2c6b3bd1f2cf8125806a75..aa974ee61a27de9c19e97d8a6eb48f9261ce4bd9 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -33,10 +33,32 @@ namespace se = ::perftools::gputools; namespace xla { +using tensorflow::str_util::Lowercase; + // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; +// The name of the interpreter platform. +constexpr char kInterpreter[] = "interpreter"; + +namespace { + +string CanonicalPlatformName(const string& name) { + string platform_str = Lowercase(name); + // "cpu" and "host" mean the same thing. + if (platform_str == "cpu") { + platform_str = "host"; + } + // "gpu" and "cuda" mean the same thing. + if (platform_str == "gpu") { + platform_str = "cuda"; + } + return platform_str; +} + +} // namespace + /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { se::MultiPlatformManager::PlatformMap platform_map; @@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() { return platforms; } -/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { +/* static */ StatusOr PlatformUtil::GetSolePlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); if (platforms.empty()) { return NotFound("no platforms found"); @@ -87,13 +109,77 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. - auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); }; - string platforms_string = tensorflow::str_util::Join(platforms, ", ", l); + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", platforms_string.c_str()); } +/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { + TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + if (platforms.empty()) { + return NotFound("no platforms found"); + } else if (platforms.size() == 1) { + return platforms[0]; + } else if (platforms.size() == 2) { + for (int i = 0; i < 2; i++) { + if (Lowercase(platforms[i]->Name()) == kInterpreter && + Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + return platforms[1 - i]; + } + } + } + + // Multiple platforms present and we can't pick a reasonable default. + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "must specify platform because more than one platform (except for the " + "interpreter platform) found: %s", + platforms_string.c_str()); +} + +/*static*/ StatusOr PlatformUtil::GetPlatform( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + for (se::Platform* platform : platforms) { + if (Lowercase(platform->Name()) == platform_str) { + return platform; + } + } + return InvalidArgument("platform %s not found", platform_name.c_str()); +} + +/*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); + + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + std::vector matched; + for (se::Platform* platform : platforms) { + if (Lowercase(platform->Name()) != platform_name) { + matched.push_back(platform); + } + } + if (matched.empty()) { + return InvalidArgument("unable to find platform that is not %s", + platform_name.c_str()); + } + if (matched.size() == 1) { + return matched[0]; + } + string matched_string = tensorflow::str_util::Join( + matched, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "found multiple platforms %s, but expected one platform except for %s", + matched_string.c_str(), platform_name.c_str()); +} + // Returns whether the device underlying the given StreamExecutor is supported // by XLA. static bool IsDeviceSupported(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index eac573703085aca2801885cd9abbe0022f1c029e..69188820a70707d9c9be10b20fb7de92ad4d9873 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -16,11 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ +#include #include #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -34,10 +37,27 @@ class PlatformUtil { static StatusOr> GetSupportedPlatforms(); - // Convenience function which returns the default supported platform. If + // Convenience function which returns the default supported platform for + // tests. If exactly one supported platform is present, then this platform is + // the default platform. If exactly two platforms are present and one of them + // is the interpreter platform, then the other platform is the default + // platform. Otherwise returns an error. + static StatusOr GetDefaultPlatform(); + + // Convenience function which returns the sole supported platform. If // exactly one supported platform is present, then this platform is the // default platform. Otherwise returns an error. - static StatusOr GetDefaultPlatform(); + static StatusOr GetSolePlatform(); + + // Returns the platform according to the given name. Returns error if there is + // no such platform. + static StatusOr GetPlatform( + const string& platform_name); + + // Returns exactly one platform that does not have given name. Returns error + // if there is no such platform, or there are multiple such platforms. + static StatusOr GetPlatformExceptFor( + const string& platform_name); // Returns a vector of StreamExecutors for the given platform. The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 0fb90230f2f39a841973361f63d17af579a1342b..e62bafc50b0e1270702621c9ea7b2ee43e001fe0 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -101,8 +101,9 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( IsReshapeOrTranspose(operand) && !CanTriviallyChangeShape(operand->operand(0))) { VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " - << hlo->ToStringNoMetadata() << ":\n\t" - << operand->ToStringNoMetadata(); + << hlo->ToString(HloPrintOptions().set_print_metadata(false)) + << ":\n\t" + << operand->ToString(HloPrintOptions().set_print_metadata(false)); return operand; } } @@ -133,8 +134,9 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { bool AllOperandsHaveEasyShapeChanges( const HloInstruction* instruction, const HloInstruction* first_reshape_operand) { + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); VLOG(3) << "** Checking whether all operands have easy shape changes: " - << instruction->ToStringNoMetadata(); + << instruction->ToString(print_no_metadata); // Check whether all operands: // 0. Have the same dimensions as the output -- if not, it may be // implicitly broadcast, which can confound the movement's @@ -151,21 +153,21 @@ bool AllOperandsHaveEasyShapeChanges( VLOG(5) << "Operand shape differs from output shape; may be " "implicitly broadcast, so preventing " "movement\n\toperand: " - << operand->ToStringNoMetadata() - << "\n\tinstruction: " << instruction->ToStringNoMetadata(); + << operand->ToString(print_no_metadata) << "\n\tinstruction: " + << instruction->ToString(print_no_metadata); return false; } if (AreEquivalentReshapes(first_reshape_operand, operand)) { VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToStringNoMetadata() - << "\n\toperand: " << operand->ToStringNoMetadata(); + << first_reshape_operand->ToString(print_no_metadata) + << "\n\toperand: " << operand->ToString(print_no_metadata); continue; } if (CanTriviallyChangeShape(operand)) { VLOG(5) << "Operand can trivially change shape: " - << operand->ToStringNoMetadata(); + << operand->ToString(print_no_metadata); continue; } @@ -173,12 +175,12 @@ bool AllOperandsHaveEasyShapeChanges( // well. VLOG(5) << "Operand is neither equalivant to the first Reshape operand" "nor can trivially change shape: " - << operand->ToStringNoMetadata(); + << operand->ToString(print_no_metadata); return false; } VLOG(3) << "All operands have easy shape changes: " - << instruction->ToStringNoMetadata(); + << instruction->ToString(print_no_metadata); return true; } @@ -250,11 +252,13 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, return false; } + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); // At this point we've decided to sink reshape/transpose operands. const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); VLOG(3) << "** Sinking reshape or transpose: " - << instruction->ToStringNoMetadata() << "\n\tfirst reshape operand: " - << first_reshape_operand->ToStringNoMetadata() + << instruction->ToString(print_no_metadata) + << "\n\tfirst reshape operand: " + << first_reshape_operand->ToString(print_no_metadata) << "\n\tnew operand shape: " << ShapeUtil::HumanString(new_operand_shape); @@ -267,7 +271,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, continue; } VLOG(3) << "Updating operand #" << i << ": " - << operands[i]->ToStringNoMetadata(); + << operands[i]->ToString(print_no_metadata); operands[i] = UpdateOperand(computation, first_reshape_operand, new_operand_shape, operands[i]); } @@ -298,7 +302,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, switch (first_reshape_operand->opcode()) { case HloOpcode::kReshape: VLOG(3) << "Creating new reshape for new elementwise op: " - << new_elementwise->ToStringNoMetadata(); + << new_elementwise->ToString(print_no_metadata); new_reshape = HloInstruction::CreateReshape(instruction->shape(), new_elementwise); break; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index ee9501dd4839ffcb6052df14699aad90565ae0e2..e77a46128b1dadbeea0df64a19f5ba980257cf8c 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -60,41 +60,32 @@ namespace xla { namespace { -// Copies the contents of an Allocation into a Literal proto. -tensorflow::Status LiteralFromAllocation(const Allocation* allocation, - const Shape& literal_shape, - Literal* literal) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - allocation->backend()->stream_executor(allocation->device_ordinal())); - return allocation->backend()->transfer_manager()->TransferLiteralFromDevice( - executor, allocation->device_memory(), allocation->shape(), literal_shape, - literal); -} - // Records the arguments used to invoke a computation in a SessionModule // proto. tensorflow::Status RecordArguments( - const tensorflow::gtl::ArraySlice arg_allocations, + const tensorflow::gtl::ArraySlice arguments, + se::StreamExecutor* executor, TransferManager* transfer_manager, SessionModule* module) { module->clear_arguments(); - for (const Allocation* allocation : arg_allocations) { - Literal argument; - TF_RETURN_IF_ERROR( - LiteralFromAllocation(allocation, allocation->shape(), &argument)); - *module->add_arguments() = argument.ToProto(); + for (const ShapedBuffer* argument : arguments) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, *argument)); + *module->add_arguments() = literal->ToProto(); } return tensorflow::Status::OK(); } // Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const Allocation* result_allocation, +tensorflow::Status RecordResult(const ShapedBuffer& result, + se::StreamExecutor* executor, + TransferManager* transfer_manager, SessionModule* module) { module->clear_result(); - Literal result; - TF_RETURN_IF_ERROR(LiteralFromAllocation( - result_allocation, result_allocation->shape(), &result)); - *module->mutable_result() = result.ToProto(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, result)); + *module->mutable_result() = literal->ToProto(); return tensorflow::Status::OK(); } @@ -152,7 +143,9 @@ int ServiceOptions::intra_op_parallelism_threads() const { Service::Service(const ServiceOptions& options, std::unique_ptr execute_backend) - : options_(options), execute_backend_(std::move(execute_backend)) { + : options_(options), + allocation_tracker_(execute_backend.get()), + execute_backend_(std::move(execute_backend)) { CHECK_GT(options_.number_of_replicas(), 0); if (execute_backend_) { if (execute_backend_->device_count() > 0) { @@ -235,35 +228,33 @@ tensorflow::Status Service::ValidateResultShapeWithLayout( return ShapeUtil::ValidateShape(shape_with_layout); } -StatusOr> Service::ResolveAndValidateArguments( +StatusOr> Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - const Backend* backend, int device_ordinal) { - std::vector allocations; + int device_ordinal) { + std::vector shaped_buffers; for (size_t i = 0; i < arguments.size(); ++i) { - auto allocation_status = allocation_tracker_.Resolve(*arguments[i]); - if (!allocation_status.ok()) { - return Status(allocation_status.status().code(), - StrCat(allocation_status.status().error_message(), ", ", + auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); + if (!buffer_status.ok()) { + return Status(buffer_status.status().code(), + StrCat(buffer_status.status().error_message(), ", ", "failed to resolve allocation for parameter ", i)); } - const Allocation* allocation = allocation_status.ValueOrDie(); + const ShapedBuffer* shaped_buffer = buffer_status.ValueOrDie(); // Verify allocation is same platform and device as the execution. - if (allocation->backend() != backend || - allocation->device_ordinal() != device_ordinal) { + if (shaped_buffer->platform() != execute_backend_->platform() || + shaped_buffer->device_ordinal() != device_ordinal) { return InvalidArgument( - "argument %lu is on device %s but computation will be executed " + "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, - allocation->backend() - ->device_name(allocation->device_ordinal()) - .c_str(), - backend->device_name(device_ordinal).c_str()); + i, shaped_buffer->platform()->Name().c_str(), + shaped_buffer->device_ordinal(), + execute_backend_->device_name(device_ordinal).c_str()); } - allocations.push_back(allocation); + shaped_buffers.push_back(shaped_buffer); } - return allocations; + return shaped_buffers; } StatusOr> Service::CreateModuleConfig( @@ -325,11 +316,11 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { - argument_shapes.push_back(&arg->shape()); + argument_shapes.push_back(&arg->on_host_shape()); } return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } @@ -398,8 +389,6 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, se::StreamExecutor* executor) { VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, versioned_handle.ToString().c_str()); @@ -430,9 +419,12 @@ StatusOr> Service::BuildExecutable( /*include_unreachable_instructions=*/ true)); + TF_ASSIGN_OR_RETURN( + module, backend->compiler()->RunHloPasses(std::move(module), executor)); + TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend->compiler()->Compile(std::move(module), executor)); + backend->compiler()->RunBackend(std::move(module), executor)); if (!other_directory_path.empty()) { executable->set_session_module(std::move(session_module)); @@ -444,8 +436,6 @@ StatusOr> Service::BuildExecutable( StatusOr> Service::BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile) { std::shared_ptr executable = @@ -468,8 +458,8 @@ StatusOr> Service::BuildAndCacheExecutable( HloModuleConfig original_module_config = *module_config; TF_ASSIGN_OR_RETURN( std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), arguments, - backend, executor)); + BuildExecutable(versioned_handle, std::move(module_config), backend, + executor)); if (profile != nullptr) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -486,9 +476,7 @@ StatusOr> Service::BuildAndCacheExecutable( StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice< - std::vector> - arguments, + tensorflow::gtl::ArraySlice> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile) { @@ -544,7 +532,7 @@ Service::ExecuteParallelAndRegisterResult( // Asynchronously launch the computation. TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase result, + std::unique_ptr result, executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); if (replica == 0 && profile != nullptr) { @@ -554,17 +542,20 @@ Service::ExecuteParallelAndRegisterResult( // All replicas share the same device address for the result allocation, // so only one of the replicas need to register the result handle. if (replica == 0) { - result_handles.push_back(allocation_tracker_.Register( - backend, replicas[0]->device_ordinal(), result, - executables[i]->result_shape(), result_tags[i])); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle handle, + allocation_tracker_.Register(std::move(result), result_tags[i])); + result_handles.push_back(handle); } } } // Wait for all executions to complete. for (int64 i = 0; i < streams.size(); ++i) { - if (!streams[i]->BlockHostUntilDone()) { - return InternalError("failed to complete execution for stream %lld", i); + Status block_status = streams[i]->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("failed to complete execution for stream %lld: %s", + i, block_status.error_message().c_str()); } } @@ -572,12 +563,13 @@ Service::ExecuteParallelAndRegisterResult( // profile. for (auto& index_to_profiled_stream : index_to_profiled_streams) { int64 device = index_to_profiled_stream.first; - auto& module = executables[device]->module(); se::Stream* stream = index_to_profiled_stream.second; - HloExecutionProfile hlo_profile(module, - *executables[device]->CreateCostAnalysis()); - TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile( - &hlo_profile, stream->parent())); + Executable* executable = executables[device]; + const HloModule& module = executable->module(); + HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(), + &executable->hlo_profile_index_map()); + TF_RETURN_IF_ERROR( + executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); XLA_LOG_LINES( tensorflow::INFO, hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); @@ -621,8 +613,7 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile) { // Set up streams. @@ -647,6 +638,7 @@ StatusOr Service::ExecuteAndRegisterResult( for (const Pool::SmartPtr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); + options.set_device_ordinal(stream->parent()->device_ordinal()); options.set_allocator(backend->memory_allocator()); options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( @@ -656,24 +648,23 @@ StatusOr Service::ExecuteAndRegisterResult( backend->inter_op_thread_pool()); } - perftools::gputools::DeviceMemoryBase result; + std::unique_ptr result; if (options_.number_of_replicas() == 1) { TF_ASSIGN_OR_RETURN( - result, executable->ExecuteOnStreamWrapper( - &run_options[0], profile, arguments)); + result, + executable->ExecuteOnStreamWrapper>( + &run_options[0], profile, arguments)); } else { - std::vector< - tensorflow::gtl::ArraySlice> + // TODO(b/69985541): Support profiling also on this path. + std::vector> repeated_arguments(options_.number_of_replicas(), arguments); TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( run_options, repeated_arguments)); TF_RET_CHECK(!results.empty()); - result = results[0]; + result = std::move(results[0]); } - return allocation_tracker_.Register(backend, executor->device_ordinal(), - result, executable->result_shape(), - result_tag); + return allocation_tracker_.Register(std::move(result), result_tag); } tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, @@ -687,7 +678,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - std::vector> all_arguments; + std::vector> all_arguments; std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; @@ -744,19 +735,14 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // In the case of partitioned computations, assume all arguments go on the // zeroth core. TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(request.arguments(), executors[0]->device_ordinal())); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } // Create an HloModuleConfig object for the computation, given the shape of // the program and the argument allocations. TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, + CreateModuleConfig(*program_shape, arguments, request.execution_options())); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -859,35 +845,30 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, user_computation->ComputeProgramShape(versioned_handle.version)); TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(arg->arguments(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, arg->execution_options())); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } - TF_ASSIGN_OR_RETURN( std::shared_ptr executable, BuildAndCacheExecutable(versioned_handle, std::move(module_config), - arguments, execute_backend_.get(), + execute_backend_.get(), execute_backend_->default_stream_executor(), result->mutable_profile())); if (executable->dumping()) { executable->session_module()->set_execution_platform( execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR( - RecordArguments(arg_allocations, executable->session_module())); + TF_RETURN_IF_ERROR(RecordArguments( + arguments, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->session_module())); } TF_ASSIGN_OR_RETURN( @@ -898,10 +879,11 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, "result of " + user_computation->name(), result->mutable_profile())); if (executable->dumping()) { - TF_ASSIGN_OR_RETURN(const Allocation* result_allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, allocation_tracker_.Resolve(result->output())); - TF_RETURN_IF_ERROR( - RecordResult(result_allocation, executable->session_module())); + TF_RETURN_IF_ERROR(RecordResult( + *result_buffer, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->session_module())); TF_RETURN_IF_ERROR(executable->DumpSessionModule()); } @@ -927,31 +909,24 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, user_computation->ComputeProgramShape(versioned_handle.version)); TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(arg->arguments(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, arg->execution_options())); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } - ExecutionProfile profile; TF_ASSIGN_OR_RETURN( std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - arguments, execute_backend_.get(), - execute_backend_->default_stream_executor(), - &profile)); + BuildAndCacheExecutable( + versioned_handle, std::move(module_config), execute_backend_.get(), + execute_backend_->default_stream_executor(), &profile)); TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); @@ -966,7 +941,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, streams.push_back(std::move(stream)); } - perftools::gputools::DeviceMemoryBase result_data; + std::unique_ptr result_buffer; for (const Pool::SmartPtr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); @@ -979,19 +954,19 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, options, execute_backend_->StreamBorrower()); TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase this_result_data, + std::unique_ptr this_result_buffer, executable->ExecuteAsyncOnStream(&service_options, arguments)); // Take the first result. - if (result_data == nullptr) { - result_data = this_result_data; + if (result_buffer == nullptr) { + result_buffer = std::move(this_result_buffer); } } - auto output = allocation_tracker_.Register( - execute_backend_.get(), execute_backend_->default_device_ordinal(), - result_data, executable->result_shape(), - "result of " + user_computation->name()); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle output, + allocation_tracker_.Register(std::move(result_buffer), + "result of " + user_computation->name())); *result->mutable_execution() = execution_tracker_.Register( execute_backend_.get(), std::move(streams), profile, output); @@ -1018,38 +993,58 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, TransferToClientResponse* result) { - TF_ASSIGN_OR_RETURN(const Allocation* allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.Resolve(arg->data())); - const Shape* literal_shape; + const Shape* return_shape; if (arg->has_shape_with_layout()) { if (!LayoutUtil::HasLayout(arg->shape_with_layout())) { return InvalidArgument("shape_with_layout must have layout if present."); } - literal_shape = &arg->shape_with_layout(); + return_shape = &arg->shape_with_layout(); } else { - literal_shape = &allocation->shape(); + return_shape = &shaped_buffer->on_host_shape(); } - Literal literal; - TF_RETURN_IF_ERROR( - LiteralFromAllocation(allocation, *literal_shape, &literal)); - *result->mutable_literal() = literal.ToProto(); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(shaped_buffer->device_ordinal())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_literal, + execute_backend_->transfer_manager()->TransferLiteralFromDevice( + executor, *shaped_buffer)); + + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, + result_literal->shape())) { + *result->mutable_literal() = result_literal->ToProto(); + } else { + *result->mutable_literal() = + result_literal->Relayout(*return_shape)->ToProto(); + } return tensorflow::Status::OK(); } +namespace { + +// Creates a clone of the given shaped buffer with the given device ordinal. The +// shape and DeviceMemoryBase values of the clone are identical to the original. +std::unique_ptr CloneShapedBufferOnDevice( + const ShapedBuffer& shaped_buffer, int device_ordinal) { + auto clone = MakeUnique( + shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), + shaped_buffer.platform(), device_ordinal); + clone->buffers() = shaped_buffer.buffers(); + return clone; +} + +} // namespace + tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { Literal literal = Literal(arg->literal()); const Shape& shape = literal.shape(); - if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) { - // TODO(b/32990684): Tuple transfers to host end up allocating further - // buffers - implement that correctly. - return Unimplemented( - "Tuple transfers to the device not supported with replication."); - } - std::vector replicas; if (arg->has_device_handle()) { TF_ASSIGN_OR_RETURN(replicas, @@ -1059,25 +1054,38 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } - // Allocate memory on the device, using the stream executor. The size of the - // allocation is obtained by examining the shape of the literal passed from - // the client. An allocation handle is returned in the response. - int64 allocation_size = - execute_backend_->transfer_manager()->GetByteSizeRequirement(shape); - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, - execute_backend_->memory_allocator()->Allocate( - replicas[0]->device_ordinal(), allocation_size)); - - *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape, - StrCat("TransferToServer literal of size ", allocation_size)); + // All memory allocation is done on the first replica. The allocations in all + // other replicas mirror the firsts'. + int master_device_ordinal = replicas[0]->device_ordinal(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr shaped_buffer, + execute_backend_->transfer_manager()->AllocateShapedBuffer( + shape, execute_backend_->memory_allocator(), master_device_ordinal)); + // Transfer the data to the replicas. for (se::StreamExecutor* executor : replicas) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, literal, &allocation)); + if (executor->device_ordinal() == master_device_ordinal) { + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, literal, *shaped_buffer)); + } else { + // The replica is not the master. Create an cloned shaped buffer with + // the replica's device ordinal. This is required because + // TransferLiteralToDevice verifies that the device ordinal of the shaped + // buffer matches that of the executor. + std::unique_ptr clone = + CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal()); + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, literal, *clone)); + } } + TF_ASSIGN_OR_RETURN( + *result->mutable_data(), + allocation_tracker_.Register(std::move(shaped_buffer), + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape)))); + return tensorflow::Status::OK(); } @@ -1228,8 +1236,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, [](const Literal& literal) { return &literal; }); HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate(*module, parameter_ptrs)); + TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( + *module, parameter_ptrs)); + // Since the shape_with_output_layout option in ExecutionOption is // non-effective to the Evaluator results, explicit relayout here. if (arg->has_output_layout()) { @@ -1242,9 +1251,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, tensorflow::Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { - TF_ASSIGN_OR_RETURN(const Allocation* allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.Resolve(arg->data())); - *result->mutable_shape() = allocation->shape(); + *result->mutable_shape() = buffer->on_host_shape(); return tensorflow::Status::OK(); } @@ -1353,6 +1362,17 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddConcatenateInstruction(arg->concatenate_request()); break; + case OpRequest::kConditionalRequest: { + TF_ASSIGN_OR_RETURN(UserComputation * true_computation, + computation_tracker_.Resolve( + arg->conditional_request().true_computation())); + TF_ASSIGN_OR_RETURN(UserComputation * false_computation, + computation_tracker_.Resolve( + arg->conditional_request().false_computation())); + handle_status = computation->AddConditionalInstruction( + arg->conditional_request(), *true_computation, *false_computation); + break; + } case OpRequest::kConstantRequest: handle_status = computation->AddConstantInstruction(arg->constant_request()); @@ -1361,6 +1381,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddConvertInstruction(arg->convert_request()); break; + case OpRequest::kBitcastConvertRequest: + handle_status = computation->AddBitcastConvertInstruction( + arg->bitcast_convert_request()); + break; case OpRequest::kConvolveRequest: handle_status = computation->AddConvolveInstruction(arg->convolve_request()); @@ -1373,6 +1397,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddCustomCallInstruction(arg->custom_call_request()); break; + case OpRequest::kDotRequest: + handle_status = computation->AddDotInstruction(arg->dot_request()); + break; case OpRequest::kDynamicSliceRequest: handle_status = computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); @@ -1493,8 +1520,12 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddRecvInstruction(arg->recv_request()); break; } + case OpRequest::kFftRequest: + return Unimplemented("FftRequest not implemented in XLA service."); + case OpRequest::OP_NOT_SET: + return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); default: - return InvalidArgument("Unsupported operation"); + return InvalidArgument("Unsupported operation in XLA service"); } TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 47f4f0ade594089aa71717ef1e122886b0a6c7ac..f962d0cdc7d41e1aeab55da5abcb1b40215b4144 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -250,7 +250,7 @@ class Service : public ServiceInterface { // class. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options); protected: @@ -265,10 +265,10 @@ class Service : public ServiceInterface { // Resolves the given argument handles in the allocation tracker and returns // the corresponding allocations. The function also verifies that each - // allocation matches the given backend and device ordinal. - StatusOr> ResolveAndValidateArguments( + // allocation matches the execution platform and device ordinal. + StatusOr> ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - const Backend* backend, int device_ordinal); + int device_ordinal); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. @@ -281,8 +281,6 @@ class Service : public ServiceInterface { StatusOr> BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor); // Same as BuildExecutable() above, but builds a list of Executables for the @@ -299,8 +297,6 @@ class Service : public ServiceInterface { StatusOr> BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile); @@ -310,8 +306,7 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile); @@ -320,9 +315,7 @@ class Service : public ServiceInterface { // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice< - std::vector> - arguments, + tensorflow::gtl::ArraySlice> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index dcd726f22c71b4bd709dc63b25d6fdea477c83c7..9c1b951d017569a6dc89bc6583c72b5e42f0c07c 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -90,8 +91,6 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_ATAN2; case HloOpcode::kComplex: return BINOP_COMPLEX; - case HloOpcode::kDot: - return BINOP_DOT; case HloOpcode::kMultiply: return BINOP_MUL; case HloOpcode::kAdd: @@ -441,6 +440,14 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { + auto old_element_type = operand_shape.element_type(); + if (primitive_util::IsComplexType(old_element_type) && + !primitive_util::IsComplexType(new_element_type)) { + return Unimplemented( + "Unsupported conversion from complex to real type: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -454,6 +461,36 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } +/* static */ StatusOr ShapeInference::InferBitcastConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type) { + auto old_element_type = operand_shape.element_type(); + if (primitive_util::IsComplexType(old_element_type) != + primitive_util::IsComplexType(new_element_type)) { + return Unimplemented( + "Unsupported conversion between real and complex types: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } + if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + // Note: we may want to support tuple conversions via this operation in the + // future, by recursing into the tuple elements to check all sub-conversions + // are valid. For now we just reject them, though. + return InvalidArgument( + "cannot convert from or to tuple type; requested conversion: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } + if (primitive_util::BitWidth(old_element_type) != + primitive_util::BitWidth(new_element_type)) { + return InvalidArgument( + "cannot bitcast types with different bit-widths: %s => %s", + PrimitiveType_Name(old_element_type).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } + + return ShapeUtil::ChangeElementType(operand_shape, new_element_type); +} + /* static */ StatusOr ShapeInference::InferReducePrecisionShape( const Shape& operand_shape, const int exponent_bits, const int mantissa_bits) { @@ -511,8 +548,113 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); } -/* static */ StatusOr ShapeInference::InferDotOpShape(const Shape& lhs, - const Shape& rhs) { +// Current DotDimensionNumbers Requirements: +// +// Contracting Dimensions: +// *) Exactly one contracting dimension on both lhs and rhs. +// *) Contracting dimension size must be the same on both lhs and rhs. +// *) Contracting dimension numbers do not need to be the same (i.e. transposes +// are passed on to emitter implementations). +// +// Batch Dimensions: +// *) Same number of batch dimensions on both lhs and rhs. +// *) Same batch dimension numbers (and sizes) on both lhs and rhs. +// *) Batch dimension numbers must be ordered before contracting and +// non-contracting/non-batch dimension numbers. +// +// Non-Contracting-Non-Batch Dimensions: +// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// + +namespace { + +Status ValidateDotDimensionNumbers( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { + // Check that dimension numbers are in range. + auto dims_in_range = + [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + in_range) && + std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + }; + + tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice lhs_batch_dimensions = + AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); + tensorflow::gtl::ArraySlice rhs_batch_dimensions = + AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); + + if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + lhs_batch_dimensions) || + !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is out of range in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that dimension numbers are unique. + auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + tensorflow::gtl::FlatSet dim_set; + auto is_unique = [&dim_set](int64 i) -> bool { + return dim_set.insert(i).second; + }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + is_unique) && + std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + }; + + if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || + !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is not unique in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. + const int64 lhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(lhs) - + dimension_numbers.lhs_contracting_dimensions_size() - + dimension_numbers.lhs_batch_dimensions_size(); + const int64 rhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(rhs) - + dimension_numbers.rhs_contracting_dimensions_size() - + dimension_numbers.rhs_batch_dimensions_size(); + if (lhs_non_contracting_non_batch_dims < 0 || + lhs_non_contracting_non_batch_dims > 1 || + rhs_non_contracting_non_batch_dims < 0 || + rhs_non_contracting_non_batch_dims > 1) { + return InvalidArgument( + "batch and contracting dimension number mismatch " + "with rank "); + } + + // Check that batch dimension numbers are ordered before all others, and + // that they are monotonically increasing. + std::vector batch_dim_numbers(lhs_batch_dimensions.size()); + std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); + if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + lhs_batch_dimensions.begin()) || + !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + rhs_batch_dimensions.begin())) { + return InvalidArgument( + "batch dimension numbers must precede non-batch dimensions and be" + "monotonically increasing."); + } + + return Status::OK(); +} + +} // namespace + +/* static */ StatusOr ShapeInference::InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); @@ -532,37 +674,62 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return fail("element types do not match"); } - if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || - ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) { - return fail("dot only supports rank 1 or 2"); + if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + return fail("dot only supports rank 1 or above."); + } + + // Validate basic properties of dot dimension numbers. + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + + // Check that there is only one contracting dimension for both lhs and rhs. + if (dimension_numbers.lhs_contracting_dimensions_size() != + dimension_numbers.rhs_contracting_dimensions_size() || + dimension_numbers.lhs_contracting_dimensions_size() != 1) { + return fail("must specify one contracting dimension for both lhs and rhs."); } - // Determine the index of the contracted dimensions for input tensors. - // dimensions -1 of lhs and dimension 0 of rhs are contracted. - int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1); - int64 rhs_contracted_dimension = 0; + // Check that contracting dimension sizes match. + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(0); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension)) { + return fail("contracting dimension sizes do not match."); + } - // Check if the contracted dimension sizes are the same. - if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) && - rhs_contracted_dimension < ShapeUtil::Rank(rhs)) && - lhs.dimensions(lhs_contracted_dimension) != - rhs.dimensions(rhs_contracted_dimension)) { - return fail("contracted dimensions mismatch"); + // Check that number of batch dimensions match. + if (dimension_numbers.lhs_batch_dimensions_size() != + dimension_numbers.rhs_batch_dimensions_size()) { + return fail("must the same number of batch dimensions for lhs and rhs."); + } + + // Check that batch dimension numbers and sizes match. + for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { + if (dimension_numbers.lhs_batch_dimensions(i) != + dimension_numbers.rhs_batch_dimensions(i) || + lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { + return fail("batch dimension numbers and sizes must match for lhs/rhs."); + } } // The ranks of lhs and rhs are decremented by 1 respectively due to the // contraction, and added for the rank of the result. When an input tensor is // a scalar, its contribution to the rank of the result is 0. // Generate the result dimensions in order, rhs dimensions followed by lhs - // dimensions except the contracted dimensions. + // dimensions except the contracted and batch dimensions. std::vector dimensions; + std::unordered_set rhs_batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracted_dimension) { + if (i != lhs_contracting_dimension) { dimensions.push_back(lhs.dimensions(i)); } } for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracted_dimension) { + if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { dimensions.push_back(rhs.dimensions(i)); } } @@ -778,8 +945,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( rhs, tensorflow::strings::StrCat("rhs of binary operation ", BinaryOperation_Name(operation)))); switch (operation) { - case BINOP_DOT: - return InferDotOpShape(lhs, rhs); case BINOP_MAX: case BINOP_MIN: case BINOP_SUB: @@ -1407,7 +1572,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } - if (dnums.spatial_dimensions_size() != + if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" @@ -1415,7 +1580,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( window.DebugString().c_str()); } - const int num_spatial_dims = dnums.spatial_dimensions_size(); + const int num_spatial_dims = dnums.input_spatial_dimensions_size(); if (window.dimensions_size() != num_spatial_dims) { return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" @@ -1444,8 +1609,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( std::vector input_dnums(num_dims); input_dnums[0] = dnums.input_batch_dimension(); input_dnums[1] = dnums.input_feature_dimension(); - std::copy(dnums.spatial_dimensions().begin(), - dnums.spatial_dimensions().end(), input_dnums.begin() + 2); + std::copy(dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); std::sort(input_dnums.begin(), input_dnums.end()); std::vector window_dnums(num_dims); @@ -1455,12 +1620,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); std::sort(window_dnums.begin(), window_dnums.end()); + std::vector output_dnums(num_dims); + output_dnums[0] = dnums.output_batch_dimension(); + output_dnums[1] = dnums.output_feature_dimension(); + std::copy(dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); + std::sort(output_dnums.begin(), output_dnums.end()); + std::vector expected_dnums(num_dims); std::iota(expected_dnums.begin(), expected_dnums.end(), 0); const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || - !std::all_of(window_dnums.begin(), window_dnums.end(), in_range)) { + !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || + !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s", dnums.DebugString().c_str()); @@ -1478,10 +1651,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "once: %s", dnums.DebugString().c_str()); } + if (output_dnums != expected_dnums) { + return InvalidArgument( + "Output dimensions of convolution must contain each dimension exactly " + "once: %s", + dnums.DebugString().c_str()); + } std::vector input_spatial_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { - input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i)); + input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i)); } const int64 input_features = lhs.dimensions(dnums.input_feature_dimension()); const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension()); @@ -1529,17 +1708,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dimensions[dnums.output_batch_dimension()] = input_batch; dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { - dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i); + dimensions[dnums.output_spatial_dimensions(i)] = + window_output_shape.dimensions(i); } return ShapeUtil::MakeShape(lhs.element_type(), dimensions); } /* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( - const Shape& operand) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand, "operand of cross replica sum")); - return operand; + tensorflow::gtl::ArraySlice operand_shapes) { + for (const Shape* operand_shape : operand_shapes) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum")); + } + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); + } + return ShapeUtil::MakeTupleShape(operand_shape_values); } /* static */ StatusOr ShapeInference::InferReduceShape( @@ -1905,6 +2094,64 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return init; } +/* static */ StatusOr ShapeInference::InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation) { + if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { + return InvalidArgument("predicate must be a boolean; got %s.", + ShapeUtil::HumanString(predicate).c_str()); + } + + if (true_computation.parameters_size() != 1) { + return InvalidArgument("true_computation must take 1 argument; got %d.", + true_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { + auto true_shape_string = [&]() { + return tensorflow::strings::Printf( + "true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand).c_str(), + ShapeUtil::HumanString(true_computation).c_str()); + }; + return InvalidArgument( + "true_operand must match the shape of the only parameter of " + "true_computation: got %s.", + true_shape_string().c_str()); + } + + if (false_computation.parameters_size() != 1) { + return InvalidArgument("false_computation must take 1 argument; got %d.", + false_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { + auto false_shape_string = [&]() { + return tensorflow::strings::Printf( + "false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand).c_str(), + ShapeUtil::HumanString(false_computation).c_str()); + }; + return InvalidArgument( + "false_operand must match the shape of the only parameter of " + "false_computation: got %s.", + false_shape_string().c_str()); + } + if (!ShapeUtil::Compatible(true_computation.result(), + false_computation.result())) { + auto shape_string = [&]() { + return tensorflow::strings::Printf( + "true_computation result: %s; false_computation result: %s.", + ShapeUtil::HumanString(true_computation.result()).c_str(), + ShapeUtil::HumanString(false_computation.result()).c_str()); + }; + return InvalidArgument( + "the result of true_computation and false_computation must have the " + "same shape: got %s.", + shape_string().c_str()); + } + return true_computation.result(); +} + /* static */ StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d5d497176d6c340d8c8f34cdacf6a9e32040c387..c06340d2d5df239642eb0af4836df64a898a1eaf 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,8 +109,10 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); - // Infers the shape produced a cross replica sum with the given operand shape. - static StatusOr InferCrossReplicaSumShape(const Shape& operand); + // Infers the shape produced a cross replica sum with the given operand + // shapes. + static StatusOr InferCrossReplicaSumShape( + tensorflow::gtl::ArraySlice operand_shapes); // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. @@ -178,6 +180,12 @@ class ShapeInference { const ProgramShape& body, const Shape& init); + // Infers the shape produced by a conditional operation. + static StatusOr InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation); + // Infers the shape produced by a broadcast operation. static StatusOr InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); @@ -204,6 +212,13 @@ class ShapeInference { static StatusOr InferConvertShape(const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the given operand shape can be bitcast converted to + // the target output_shape via a bitcast convert instruction -- the + // requirement is that the shape is identical except for the element type and + // the element types have identical bit-widths. + static StatusOr InferBitcastConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the input data type for a reduce-precision operation, // and returns the result shape. static StatusOr InferReducePrecisionShape(const Shape& operand_shape, @@ -222,11 +237,13 @@ class ShapeInference { tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply); - private: // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. - static StatusOr InferDotOpShape(const Shape& lhs, const Shape& rhs); + static StatusOr InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. // Note: By "element-wise" we mean operations that look at a single element in diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index d12f7bd1453890db3280e54719a6ce811006336d..99d87f3b550ae72befe254f23fad080dd210aaf4 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -395,8 +395,10 @@ TEST_F(ShapeInferenceTest, Convolve) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); @@ -437,8 +439,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); @@ -480,8 +484,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4}); @@ -524,8 +530,10 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dnums.set_output_batch_dimension(3); dnums.set_input_feature_dimension(2); dnums.set_output_feature_dimension(2); - dnums.add_spatial_dimensions(0); - dnums.add_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(0); + dnums.add_output_spatial_dimensions(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0 dnums.set_kernel_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); @@ -890,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar vector: error TEST_F(ShapeInferenceTest, ScalarDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); + ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("dot only supports rank")); @@ -899,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { // 3D 2D: error TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("batch and contracting dimension number mismatch")); } // vector vector -> scalar TEST_F(ShapeInferenceTest, VectorDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix vector -> vector TEST_F(ShapeInferenceTest, MatrixDotVector) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // vector matrix -> vector TEST_F(ShapeInferenceTest, VectorDotMatrix) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix matrix -> matrix TEST_F(ShapeInferenceTest, MatrixDotMatrix) { - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status_match = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } +// BatchMatMul with two batch dimensions and one contracting dimension. +TEST_F(ShapeInferenceTest, DotGeneral) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status_match = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE( + ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) + << "inferred: " + << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) + << " expected: " << ShapeUtil::HumanString(output_shape); +} + +// BatchMatMul with two contracting dimensions fails. +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("must specify one contracting dimension for both " + "lhs and rhs")); +} + +// BatchMatMul with different batch dimension sizes fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with different batch dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers must precede non-batch")); +} + +// BatchMatMul with out-of-range dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is out of range")); +} + +// BatchMatMul with non-unique dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is not unique")); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. @@ -1288,5 +1437,80 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } +TEST_F(ShapeInferenceTest, Conditional) { + auto inferred_status0 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_IS_OK(inferred_status0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); + + auto inferred_status1 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, vector_32_, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)); + EXPECT_IS_OK(inferred_status1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); + + auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); + auto inferred_status2 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, tuple_f32_v32, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)); + EXPECT_IS_OK(inferred_status2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); + + auto inferred_status_error0 = ShapeInference::InferConditionalShape( + s32_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error0.ok()); + EXPECT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("predicate must be a boolean")); + + auto inferred_status_error1 = ShapeInference::InferConditionalShape( + pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)); + EXPECT_FALSE(inferred_status_error1.ok()); + EXPECT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("true_computation must take 1 argument")); + + auto inferred_status_error2 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error2.ok()); + EXPECT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("true_operand must match the shape of the only " + "parameter of true_computation")); + + auto inferred_status_error3 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)); + EXPECT_FALSE(inferred_status_error3.ok()); + EXPECT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("false_computation must take 1 argument")); + + auto inferred_status_error4 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + EXPECT_FALSE(inferred_status_error4.ok()); + EXPECT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("false_operand must match the shape of the only " + "parameter of false_computation")); + + auto inferred_status_error5 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)); + EXPECT_FALSE(inferred_status_error5.ok()); + EXPECT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("the result of true_computation and false_computation " + "must have the same shape")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index a7539a1a11d2bbd62c780890c6730dbb212307c4..c679d401c3691b14a43ce77cbe953cd4c64a9e92 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -34,58 +34,32 @@ namespace xla { using ::tensorflow::strings::Appendf; -/* static */ StatusOr> -ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, - const se::Platform* platform, - int device_ordinal, - const se::DeviceMemoryBase& buffer) { - if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Shape must be an array: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - auto shaped_buffer = - MakeUnique(shape, platform, device_ordinal); - *shaped_buffer->mutable_shape_index_to_buffer_entry()->mutable_element({}) = - 0; - *shaped_buffer->mutable_buffers() = {buffer}; - return std::move(shaped_buffer); -} - -ShapedBuffer::ShapedBuffer(const Shape& shape, const se::Platform* platform, - int device_ordinal) - : shape_(shape), +ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, + const se::Platform* platform, int device_ordinal) + : on_host_shape_(on_host_shape), + on_device_shape_(on_device_shape), platform_(platform), device_ordinal_(device_ordinal), - shape_index_to_buffer_entry_(shape) {} + buffers_(on_device_shape) {} void ShapedBuffer::clear() { - for (se::DeviceMemoryBase& memory_base : buffers_) { + for (auto& pair : buffers_) { // A default constructed DeviceMemoryBase is a null pointer. - memory_base = se::DeviceMemoryBase(); + pair.second = se::DeviceMemoryBase(); } } -void ShapedBuffer::AddBufferAtIndex( - const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& shape_index) { - *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) = - buffers().size(); - mutable_buffers()->push_back(buffer); -} - -const se::DeviceMemoryBase& ShapedBuffer::buffer( - const ShapeIndex& index) const { - return buffers_[shape_index_to_buffer_entry_.element(index)]; -} - -se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) { - return &buffers_[shape_index_to_buffer_entry_.element(index)]; -} - string ShapedBuffer::ToString() const { - string s = "ShapedBuffer(" + platform_->Name() + "):\n"; + string s = tensorflow::strings::StrCat( + "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), + "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), + ", on-device shape=" + + ShapeUtil::HumanStringWithLayout(on_device_shape()), + ":\n"); ShapeUtil::ForEachSubshape( - shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { + on_device_shape(), + [this, &s](const Shape& subshape, const ShapeIndex& index) { string shape_str; if (ShapeUtil::IsTuple(subshape)) { shape_str = "tuple"; @@ -105,53 +79,24 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { return out; } -/* static */ StatusOr> -ScopedShapedBuffer::Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn) { - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - auto shaped_buffer = - WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal)); - - // Allocate an appropriate sized buffer for each element in the shape - // including the tuple pointer arrays. - for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { - const ShapeIndex& index = pair.first; - size_t& buffer_entry = pair.second; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase memory_base, - shaped_buffer->allocator_->Allocate( - shaped_buffer->device_ordinal(), - shape_size_fn(ShapeUtil::GetSubshape( - shaped_buffer->shape(), index)))); - shaped_buffer->buffers_.push_back(memory_base); - buffer_entry = shaped_buffer->buffers_.size() - 1; - } - - return std::move(shaped_buffer); -} - /* static */ StatusOr> ScopedShapedBuffer::MakeScoped( ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator) { auto scoped_buffer = WrapUnique(new ScopedShapedBuffer( - shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal())); + shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), + allocator, shaped_buffer->device_ordinal())); scoped_buffer->buffers_ = shaped_buffer->buffers(); - scoped_buffer->shape_index_to_buffer_entry_ = - shaped_buffer->shape_index_to_buffer_entry(); - shaped_buffer->clear(); return std::move(scoped_buffer); } -ScopedShapedBuffer::ScopedShapedBuffer(const Shape& shape, +ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, DeviceMemoryAllocator* allocator, int device_ordinal) - : ShapedBuffer(shape, allocator->platform(), device_ordinal), + : ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(), + device_ordinal), allocator_(allocator) {} ScopedShapedBuffer::~ScopedShapedBuffer() { @@ -159,7 +104,8 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. std::set deallocated_opaques; - for (se::DeviceMemoryBase& memory_base : buffers_) { + for (auto& pair : buffers_) { + se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && deallocated_opaques.count(memory_base.opaque()) == 0) { deallocated_opaques.insert(memory_base.opaque()); @@ -170,13 +116,10 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { } std::unique_ptr ScopedShapedBuffer::release() { - auto shaped_buffer = - MakeUnique(shape(), platform(), device_ordinal()); - - *shaped_buffer->mutable_buffers() = buffers(); - *shaped_buffer->mutable_shape_index_to_buffer_entry() = - shape_index_to_buffer_entry(); + auto shaped_buffer = MakeUnique( + on_host_shape(), on_device_shape(), platform(), device_ordinal()); + shaped_buffer->buffers() = buffers(); clear(); return shaped_buffer; diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index fa88caa13ff734995e8ab0925f17d0d3c26b8fda..f570ebb9cbb2837d3eadc32fe269845c995f7f89 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -31,61 +31,68 @@ limitations under the License. namespace xla { // Class which encapsulates a buffer or set of buffers containing data of a -// particular XLA shape. Used for zero-copy execution interface for a -// XLA client running in the same process as the service (LocalClient), +// particular XLA shape. class ShapedBuffer { public: - // Convenience method which creates a ShapedBuffer of array shape (not a - // tuple). Its single buffer pointer is set to the given value "buffer". The - // given buffer must be large enough to store the given shape as given by - // ShapeUtil::ByteSizeOf. - static StatusOr> MakeArrayShapedBuffer( - const Shape& shape, const perftools::gputools::Platform* platform, - int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer); - - ShapedBuffer(const Shape& shape, + // Construct a ScopedShapedBuffer with null DeviceMemoryBases at each + // index. The shape of the data on the host and the device may differ because + // the device may have a different representation for different data + // types. Therefore, both the on-host and on-device shape are required. The + // on-device shape determines the number of device allocations + // (DeviceMemoryBase) held by the ShapedBuffer. + ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const perftools::gputools::Platform* platform, int device_ordinal); - const Shape& shape() const { return shape_; } + // Returns the shape of the on-host representation of the data held by this + // ShapedBuffer. + const Shape& on_host_shape() const { return on_host_shape_; } + + // Returns the shape of the on-device representation of the data held by this + // ShapedBuffer. + const Shape& on_device_shape() const { return on_device_shape_; } + const perftools::gputools::Platform* platform() const { return platform_; } int device_ordinal() const { return device_ordinal_; } + // Return the root buffer of the shape (shape index {}). + const perftools::gputools::DeviceMemoryBase& root_buffer() const { + return buffer(/*index=*/{}); + } + // Returns the buffer at the given shape index where index is defined as in // ShapeUtil::GetSubshape. const perftools::gputools::DeviceMemoryBase& buffer( - const ShapeIndex& index) const; - perftools::gputools::DeviceMemoryBase* mutable_buffer( - const ShapeIndex& index); - - // Returns the underlying structure which stores the buffer pointers. - const std::vector& buffers() const { - return buffers_; + const ShapeIndex& index) const { + return buffers_.element(index); } - std::vector* mutable_buffers() { - return &buffers_; + + // Sets the device memory buffer at the given index. + void set_buffer(const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& index) { + *buffers_.mutable_element(index) = buffer; } - // Returns the tree of indices which map to buffer pointers. - const ShapeTree& shape_index_to_buffer_entry() const { - return shape_index_to_buffer_entry_; + // Returns the underlying ShapeTree containing all the device addresses in the + // ShapedBuffer. + const ShapeTree& buffers() const { + return buffers_; } - ShapeTree* mutable_shape_index_to_buffer_entry() { - return &shape_index_to_buffer_entry_; + ShapeTree& buffers() { + return buffers_; } // Set all device memory pointers in the object to null. void clear(); - // Adds a new buffer at the given shape index. - void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& shape_index); - string ToString() const; protected: - // The shape of the device buffer with layout. - const Shape shape_; + // The shape of the data when represented on the host. + const Shape on_host_shape_; + + // The shape of the data on the device. + const Shape on_device_shape_; // The platform the memory is allocated on. const perftools::gputools::Platform* platform_; @@ -93,14 +100,8 @@ class ShapedBuffer { // The device the memory is allocated on. const int device_ordinal_; - // The list of DeviceMemoryBase pointers representing this shape. - // Note that there can be a many to one relationship between tuple elements - // and buffers. To account for this, shape_index_to_buffer_entry_ allows us - // to make from a position in a shape to an index into this list. - std::vector buffers_; - - // The tree of indices into buffers_. - ShapeTree shape_index_to_buffer_entry_; + // The tree of device buffers. Its shape is on_device_shape(). + ShapeTree buffers_; }; std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); @@ -110,20 +111,16 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); // destructed. class ScopedShapedBuffer : public ShapedBuffer { public: - // Return a newly allocated ScopedShapedBuffer of an arbitrary shape. Array - // buffers (leaves in the shape) are allocated and uninitialized. Tuple - // buffers (if any) are allocated and initialized to the backend-specific - // representation of an array of pointers to the tuple elements. - static StatusOr> Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn); - // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the // deallocation of the device memory held in the shaped buffer. All device // memory pointers in the given ShapedBuffer are set to null. static StatusOr> MakeScoped( ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator); + // Create a ScopedShapedBuffer with null DeviceMemoryBases at each index. + ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, + DeviceMemoryAllocator* allocator, int device_ordinal); + // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } @@ -138,8 +135,6 @@ class ScopedShapedBuffer : public ShapedBuffer { virtual ~ScopedShapedBuffer(); protected: - ScopedShapedBuffer(const Shape& shape, DeviceMemoryAllocator* allocator, - int device_ordinal); ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; void operator=(const ScopedShapedBuffer&) = delete; diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index d5f53ad56fb019d0ae7c27fc28706f05614ece68..2f36e2b16e0f2eed10aef811dd3cceeba6a5b8a9 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -40,6 +40,45 @@ TransferManager::GetPlatformTransferManagers() { return r; } +Status TransferManager::TransferArrayToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const perftools::gputools::DeviceMemoryBase& dest) { + const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); + TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) + << "On-device representation of " + << ShapeUtil::HumanString(literal.shape()) + << " is not an array: " << ShapeUtil::HumanString(on_device_shape); + if (dest.size() < GetByteSizeRequirement(on_device_shape)) { + return FailedPrecondition( + "Allocation on device not large enough for array: " + "%lld < %lld", + dest.size(), GetByteSizeRequirement(on_device_shape)); + } + ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, + executor->platform(), executor->device_ordinal()); + shaped_buffer.set_buffer(dest, /*index=*/{}); + return TransferLiteralToDevice(executor, literal, shaped_buffer); +} + +StatusOr> TransferManager::TransferArrayFromDevice( + perftools::gputools::StreamExecutor* executor, const Shape& shape, + const perftools::gputools::DeviceMemoryBase& source) { + TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) + << "Shape " << ShapeUtil::HumanString(shape) + << " has a differently shaped representation on-device: " + << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); + if (source.size() < GetByteSizeRequirement(shape)) { + return FailedPrecondition( + "Allocation on device not large enough for array: " + "%lld < %lld", + source.size(), GetByteSizeRequirement(shape)); + } + ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, + executor->platform(), executor->device_ordinal()); + shaped_buffer.set_buffer(source, /*index=*/{}); + return TransferLiteralFromDevice(executor, shaped_buffer); +} + /* static */ void TransferManager::RegisterTransferManager( se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { @@ -75,14 +114,12 @@ TransferManager::GetPlatformTransferManagers() { Status TransferManager::WriteTupleIndexTables( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) { - VLOG(2) << "Writing tuple index tables to ShapedBuffer rooted at " - << device_buffer.buffer(/*index=*/{}).opaque() - << "; shape: " << ShapeUtil::HumanString(device_buffer.shape()); + VLOG(2) << "Writing tuple index tables for " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { if (ShapeUtil::IsTuple(device_subshape)) { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); @@ -97,7 +134,7 @@ Status TransferManager::WriteTupleIndexTables( elements.push_back(device_buffer.buffer(element_index)); element_index.pop_back(); } - return WriteTuplePointersToDevice(executor, elements, device_subshape, + return WriteSingleTupleIndexTable(executor, elements, device_subshape, &device_memory); } @@ -143,31 +180,43 @@ Status TransferManager::TransferBufferToDevice( return Status::OK(); } -StatusOr> -TransferManager::GatherBufferPointersFromTuple( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsTuple(shape)); - - std::set buffer_pointers; - buffer_pointers.insert(source); - - TF_ASSIGN_OR_RETURN(std::vector tuple_elements, - ShallowCopyTupleFromDevice(executor, source, shape)); - for (auto i = 0; i < tuple_elements.size(); ++i) { - const Shape& element_shape = shape.tuple_shapes(i); - if (ShapeUtil::IsTuple(element_shape)) { - TF_ASSIGN_OR_RETURN( - std::set buffer_pointers_in_element, - GatherBufferPointersFromTuple(executor, tuple_elements[i], - element_shape)); - buffer_pointers.insert(buffer_pointers_in_element.begin(), - buffer_pointers_in_element.end()); - } else { - buffer_pointers.insert(tuple_elements[i]); - } +StatusOr> TransferManager::AllocateShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal) { + if (!LayoutUtil::HasLayout(on_host_shape)) { + return InvalidArgument( + "Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); + const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); + TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); + + auto shaped_buffer = WrapUnique(new ShapedBuffer( + on_host_shape, on_device_shape, allocator->platform(), device_ordinal)); + + // Allocate an appropriate sized buffer for each element in the shape + // including the tuple pointer arrays. + for (auto& pair : shaped_buffer->buffers()) { + const ShapeIndex& index = pair.first; + se::DeviceMemoryBase& memory_base = pair.second; + const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); + TF_ASSIGN_OR_RETURN(memory_base, + allocator->Allocate(shaped_buffer->device_ordinal(), + GetByteSizeRequirement(subshape))); } - return std::move(buffer_pointers); + + return std::move(shaped_buffer); +} + +StatusOr> +TransferManager::AllocateScopedShapedBuffer(const Shape& on_host_shape, + DeviceMemoryAllocator* allocator, + int device_ordinal) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr unscoped_buffer, + AllocateShapedBuffer(on_host_shape, allocator, device_ordinal)); + return ScopedShapedBuffer::MakeScoped(unscoped_buffer.get(), allocator); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index fdc123e54eb7f754c12510bef551b98da01b585d..9f2b5c4aecf0b52f610171e0c2755de577b2bd9e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -44,55 +44,47 @@ class TransferManager { // Returns the ID of the platform that this transfer manager acts on. virtual perftools::gputools::Platform::Id PlatformId() const = 0; - // Transfers the region into the provided literal using the provided - // executor. device_shape is the shape, including layout, of the data on the - // device, while literal_shape will be the shape for the literal. device_shape - // and literal_shape must be compatible, but need not have the same layout. - // TODO(b/66694934): Remove TransferLiteral* methods which accept bare - // DeviceMemoryBase. - virtual Status TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& region, - const Shape& device_shape, const Shape& literal_shape, - Literal* literal) = 0; - - // Transfers the given literal into the provided region output parameter, - // using the given executor. - virtual Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - perftools::gputools::DeviceMemoryBase* region) = 0; - - // Transfers the data held in the given ShapedBuffer into the provided literal - // using the provided executor. literal_shape will be the shape for the - // literal. The shape of the ShapedBuffer and literal_shape must be - // compatible, but need not have the same layout. + // Returns the shape of the on-device representation for the given shape on + // the host. This is intended for use with ShapedBuffer where buffers are + // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user + // needing to consider device-specific behaviors. + virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { + return host_shape; + } + + // Returns a literal containing the data held in the given ShapedBuffer. + // using the provided executor. The optional literal_shape will be the shape + // for the literal. The shape of the ShapedBuffer and + // DeviceShape(literal_shape) must be compatible, but need not have the same + // layout. virtual StatusOr> TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; // Transfers the given literal into the previously allocated device memory - // represented by the given ShapedBuffer using the given executor. + // represented by the given ShapedBuffer using the given executor. The shape + // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, + // but need not have the same layout virtual Status TransferLiteralToDevice( perftools::gputools::StreamExecutor* executor, const Literal& literal, const ShapedBuffer& device_buffer) = 0; + // Convenience methods for transferring an array to or from the device at a + // known address. This avoids having to construct a ShapedBuffer just to + // transfer an array at a known address. + Status TransferArrayToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const perftools::gputools::DeviceMemoryBase& dest); + StatusOr> TransferArrayFromDevice( + perftools::gputools::StreamExecutor* executor, const Shape& shape, + const perftools::gputools::DeviceMemoryBase& source); + // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed( perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; - // Transfer a memory block of the given size from 'source' buffer to the - // Infeed interface of the device using the given executor. - // - // size is the size to transfer from source in bytes. - // - // source is the source data that must be in the target-dependent layout that - // the Infeed HLO used in the computation expects. - virtual Status TransferBufferToInfeed( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source) = 0; - // Transfers the given literal from the Outfeed interface of the device, // using the given executor. virtual Status TransferLiteralFromOutfeed( @@ -104,37 +96,26 @@ class TransferManager { tensorflow::gtl::ArraySlice executor) = 0; - // Shallow copy a tuple from the device and create a DeviceMemoryBase object - // for each element in the tuple. A DeviceMemoryBase object refers to the - // buffer containing the data of that element. The DeviceMemoryBase objects - // are returned as a vector. - virtual StatusOr> - ShallowCopyTupleFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& shape) = 0; - // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer); - // Returns all buffer pointers that the tuple `source` refers to. Unlike - // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested - // tuples as well. Also, the returned DeviceMemoryBase objects are - // deduplicated. - StatusOr> - GatherBufferPointersFromTuple( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, const Shape& shape); - // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; - typedef std::unique_ptr (*TransferManagerCreationFunction)(); + // Allocate a ShapedBuffer which can hold data with the given on-host + // shape. The on-device shape may be different as indicated by + // HostShapeToDeviceShape. + StatusOr> AllocateShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal); + StatusOr> AllocateScopedShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal); ///// // The TransferManager class also serves as a point to register objects for @@ -144,6 +125,7 @@ class TransferManager { // assumed to be a singleton, so no ownership is transferred. // // Precondition: a platform kind must not be registered more than once. + typedef std::unique_ptr (*TransferManagerCreationFunction)(); static void RegisterTransferManager( perftools::gputools::Platform::Id platform_id, TransferManagerCreationFunction transfer_manager); @@ -154,6 +136,17 @@ class TransferManager { const perftools::gputools::Platform* platform); protected: + // Transfer a memory block of the given size from 'source' buffer to the + // Infeed interface of the device using the given executor. + // + // size is the size to transfer from source in bytes. + // + // source is the source data that must be in the target-dependent layout that + // the Infeed HLO used in the computation expects. + virtual Status TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) = 0; + // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // @@ -172,10 +165,9 @@ class TransferManager { const void* source, perftools::gputools::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region - // to construct a tuple in the platform-specific tuple representation. This - // can handle nested tuples as well. In the nested case, the element - // DeviceMemoryBase points to another array of pointers on the device. - virtual Status WriteTuplePointersToDevice( + // to construct a tuple index table in the platform-specific tuple + // representation. + virtual Status WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 8c2640adf52f10c387e7a9c09c0d73a09c054919..83185ac49e9b7c386d10d1cbc4e20dcdfdfd6cae 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -42,7 +42,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < dot.operand_count(); ++i) { auto& operand = *dot.operand(i); - if (operand.IsRank2Transpose() && operand.user_count() == 1) { + if (operand.IsRank2Transpose()) { operand_set.push_back(i); } } @@ -58,27 +58,10 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( return {}; } - const ConvolutionDimensionNumbers& dnums = - convolution.convolution_dimension_numbers(); - TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < convolution.operand_count(); ++i) { auto& operand = *convolution.operand(i); - if (operand.opcode() == HloOpcode::kTranspose && - operand.user_count() == 1) { - const auto& transpose_dimensions = operand.dimensions(); - // We can transpose the LHS so long as it doesn't move around spatial - // dimensions because ConvolutionDimensionNumbers doesn't have different - // fields for input and output spatial dimensions. - if (i == 0 && - std::any_of(dnums.spatial_dimensions().begin(), - dnums.spatial_dimensions().end(), - [&](const int64 spatial_dimension) { - return transpose_dimensions[spatial_dimension] != - spatial_dimension; - })) { - continue; - } + if (operand.opcode() == HloOpcode::kTranspose) { operand_set.push_back(i); } } @@ -118,6 +101,10 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; auto& operand_indices = pair.second; + if (operand_indices.empty()) { + return false; + } + const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); ConvolutionDimensionNumbers new_dnums = dnums; @@ -137,8 +124,9 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { transpose_dimensions[dnums.input_batch_dimension()]); new_dnums.set_input_feature_dimension( transpose_dimensions[dnums.input_feature_dimension()]); - for (const auto& spatial_dimension : dnums.spatial_dimensions()) { - CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + for (auto& input_spatial_dimension : + *new_dnums.mutable_input_spatial_dimensions()) { + input_spatial_dimension = transpose_dimensions[input_spatial_dimension]; } new_lhs = &transpose_operand; } else { diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 00462f9be1e9beb2f2694060ebfaa70b0b9dd4a0..caa1a111ad880b9dee62c1c94e32e8275c196fbf 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -64,9 +64,12 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -104,9 +107,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/transpose0, /*rhs=*/transpose1)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1, 3}), + /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -169,9 +175,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -362,10 +371,82 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { EXPECT_EQ( dnums.input_batch_dimension(), new_conv->convolution_dimension_numbers().input_feature_dimension()); - EXPECT_EQ(dnums.spatial_dimensions(0), - new_conv->convolution_dimension_numbers().spatial_dimensions(0)); - EXPECT_EQ(dnums.spatial_dimensions(1), - new_conv->convolution_dimension_numbers().spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); +} + +// Test that a transpose of every dimension in the activations gets folded into +// convolution. +TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } } // namespace diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 8f63c92e5b957189ad474459d4eed53986cecaae..066ffcd7e958ed40b324dc65da209b33bc0f98f9 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -88,8 +88,6 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kAtan2; case BINOP_COMPLEX: return HloOpcode::kComplex; - case BINOP_DOT: - return HloOpcode::kDot; case BINOP_MUL: return HloOpcode::kMultiply; case BINOP_ADD: @@ -765,6 +763,54 @@ StatusOr UserComputation::AddWhileInstruction( return handle; } +StatusOr UserComputation::AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* pred, + LookUpRequest(conditional_request.predicate())); + TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, + LookUpRequest(conditional_request.true_operand())); + TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, + LookUpRequest(conditional_request.false_operand())); + + VersionedComputationHandle::Version true_computation_version = + true_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr true_computation_shape, + true_computation.ComputeProgramShape(true_computation_version)); + + VersionedComputationHandle::Version false_computation_version = + false_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr false_computation_shape, + false_computation.ComputeProgramShape(false_computation_version)); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferConditionalShape( + pred->output_shape(), true_operand->output_shape(), + false_operand->output_shape(), + *true_computation_shape, *false_computation_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(true_computation_version); + request.add_embedded_computation_versions(false_computation_version); + *request.mutable_request()->mutable_conditional_request() = + conditional_request; + + VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << conditional_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddBroadcastInstruction( const BroadcastRequest& broadcast_request) { tensorflow::mutex_lock lock(mutex_); @@ -994,6 +1040,32 @@ StatusOr UserComputation::AddConvertInstruction( return handle; } +StatusOr UserComputation::AddBitcastConvertInstruction( + const ConvertRequest& convert_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(convert_request.operand())); + + TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( + operand->output_shape(), + convert_request.new_element_type())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_bitcast_convert_request() = + convert_request; + + VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << convert_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddReducePrecisionInstruction( const ReducePrecisionRequest& reduce_precision_request) { tensorflow::mutex_lock lock(mutex_); @@ -1056,7 +1128,7 @@ StatusOr UserComputation::AddCrossReplicaSumInstruction( TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(cross_replica_sum_request.operand())); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - operand->output_shape())); + {&operand->output_shape()})); ComputationDataHandle handle = CreateComputationDataHandle(); @@ -1181,6 +1253,33 @@ StatusOr UserComputation::AddCustomCallInstruction( return handle; } +StatusOr UserComputation::AddDotInstruction( + const DotRequest& dot_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookUpRequest(dot_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookUpRequest(dot_request.rhs())); + + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( + lhs->output_shape(), rhs->output_shape(), + dot_request.dimension_numbers())); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_dot_request() = dot_request; + + VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << dot_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddUnaryInstruction( const UnaryOpRequest& unary_request) { tensorflow::mutex_lock lock(mutex_); @@ -1407,7 +1506,7 @@ StatusOr LookUpRequest( return &session_computation.requests().at(handle_value); } -// Returns the OperationRequestion corresponding to the root (result) of the +// Returns the OperationRequest corresponding to the root (result) of the // session computation. StatusOr GetRoot( VersionedComputationHandle::Version version, @@ -1453,8 +1552,8 @@ UserComputation::ComputeProgramShape( request.request().parameter_request(); int64 param_no = parameter_request.parameter(); // Parameters may be out of order so expand ProgramShape parameters - // until - // it is at least large enough to hold the current parameter number. + // until it is at least large enough to hold the current parameter + // number. while (program_shape->parameters_size() <= param_no) { program_shape->add_parameters(); program_shape->add_parameter_names(); @@ -1603,6 +1702,15 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + PureFunctionalVisitor(session_computation, dot_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, dot_request.rhs(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kSendRequest: { *is_functional = false; break; @@ -1713,6 +1821,14 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + PureFunctionalVisitor(session_computation, convert_request.operand(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); PureFunctionalVisitor(session_computation, while_request.init(), @@ -1723,6 +1839,23 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + PureFunctionalVisitor(session_computation, + conditional_request.predicate(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.true_operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.false_operand(), num_parameters, + visited, is_functional); + // TODO(b/32495713): We aren't checking the true and false computations + // themselves. + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -1951,6 +2084,21 @@ UserComputation::GetEmbeddedComputations( break; } + case OpRequest::kConditionalRequest: { + CHECK_EQ(2, request.embedded_computation_versions_size()); + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + const VersionedComputationHandle true_computation_versioned_handle = { + conditional_request.true_computation(), + request.embedded_computation_versions(0)}; + computations.push_back(true_computation_versioned_handle); + const VersionedComputationHandle false_computation_versioned_handle = + {conditional_request.false_computation(), + request.embedded_computation_versions(1)}; + computations.push_back(false_computation_versioned_handle); + break; + } + default: // No embedded computation. break; @@ -2037,6 +2185,16 @@ Status UserComputation::RemapEmbeddedComputations( TF_RETURN_IF_ERROR(update(while_request->mutable_body())); break; } + case OpRequest::kConditionalRequest: { + TF_RET_CHECK(2 == request.embedded_computation_versions_size()); + ConditionalRequest* conditional_request = + request.mutable_request()->mutable_conditional_request(); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_true_computation())); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_false_computation())); + break; + } default: // No embedded computation. TF_RET_CHECK(0 == request.embedded_computation_versions_size()); @@ -2370,12 +2528,28 @@ static void ForEachOperand( break; } + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + apply(convert_request.operand()); + break; + } + case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); apply(while_request.init()); break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + apply(conditional_request.predicate()); + apply(conditional_request.true_operand()); + apply(conditional_request.false_operand()); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -2412,6 +2586,13 @@ static void ForEachOperand( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + apply(dot_request.rhs()); + apply(dot_request.lhs()); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -2691,13 +2872,22 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + HloInstruction* lhs = lookup_instruction(dot_request.lhs()); + HloInstruction* rhs = lookup_instruction(dot_request.rhs()); + hlo_instruction = add_instruction(HloInstruction::CreateDot( + request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); HloInstruction* operand = lookup_instruction(cross_replica_sum_request.operand()); hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), operand)); + request.output_shape(), {operand})); break; } @@ -2954,6 +3144,15 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + HloInstruction* operand = lookup_instruction(convert_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert( + request.output_shape(), operand)); + break; + } + case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); CHECK_EQ(2, request.embedded_computation_versions_size()); @@ -2971,6 +3170,30 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + CHECK_EQ(2, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version true_computation_version = + request.embedded_computation_versions(0); + HloComputation* true_computation = ResolveComputation( + conditional_request.true_computation(), true_computation_version); + VersionedComputationHandle::Version false_computation_version = + request.embedded_computation_versions(1); + HloComputation* false_computation = ResolveComputation( + conditional_request.false_computation(), false_computation_version); + HloInstruction* predicate = + lookup_instruction(conditional_request.predicate()); + HloInstruction* true_operand = + lookup_instruction(conditional_request.true_operand()); + HloInstruction* false_operand = + lookup_instruction(conditional_request.false_operand()); + hlo_instruction = add_instruction(HloInstruction::CreateConditional( + request.output_shape(), predicate, true_operand, true_computation, + false_operand, false_computation)); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -2978,6 +3201,25 @@ void ComputationLowerer::Visit( HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); + + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { + if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { + // lhs side is being implicitly broadcast. Change to explicit. + lhs = + ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); + } + + if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { + rhs = + ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); + } + + if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { + ehs = + ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); + } + } + hlo_instruction = add_instruction(HloInstruction::CreateTernary( request.output_shape(), hlo_opcode, lhs, rhs, ehs)); break; @@ -3082,8 +3324,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - binary_op_request.binop() != BINOP_DOT) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. lhs = @@ -3137,7 +3378,7 @@ void ComputationLowerer::Visit( LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); } (*instructions)[handle.handle()] = hlo_instruction; -} +} // NOLINT(readability/fn_size) } // namespace diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index ac879ce55a75f6241a39f935b79017be46c1816b..8a78d520e19024f5e397d6e0c2f4e0523264e176 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -70,7 +70,7 @@ class UserComputation { // Enqueues a pad instruction onto this user computation. StatusOr AddPadInstruction( - const PadRequest& parameter_request); + const PadRequest& pad_request); // Enqueues a tracing instruction onto this user computation. // Returns an error status if the operand cannot be resolved. @@ -105,7 +105,7 @@ class UserComputation { // Enqueues a ternary instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. StatusOr AddTernaryInstruction( - const TernaryOpRequest& request); + const TernaryOpRequest& ternary_request); // Enqueues a variadic instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. @@ -153,6 +153,10 @@ class UserComputation { StatusOr AddCustomCallInstruction( const CustomCallRequest& custom_call_request); + // Enqueues a dot instruction onto this user computation. + StatusOr AddDotInstruction( + const DotRequest& dot_request); + // Enqueues a broadcast instruction onto this user computation. StatusOr AddBroadcastInstruction( const BroadcastRequest& broadcast_request); @@ -179,26 +183,30 @@ class UserComputation { // Enqueues a concatenate instruction onto this user computation. StatusOr AddConcatenateInstruction( - const ConcatenateRequest& slice_request); + const ConcatenateRequest& concatenate_request); // Enqueues a convert instruction onto this user computation. StatusOr AddConvertInstruction( const ConvertRequest& convert_request); + // Enqueues a bitcast element instruction onto this user computation. + StatusOr AddBitcastConvertInstruction( + const ConvertRequest& convert_request); + // Enqueues a reduce instruction onto this user computation. StatusOr AddReduceInstruction( const ReduceRequest& reduce_request, - const UserComputation& reduction_computation); + const UserComputation& to_apply_computation); // Enqueues a windowed reduce instruction onto this user computation. StatusOr AddReduceWindowInstruction( const ReduceWindowRequest& reduce_window_request, - const UserComputation& reduction_computation); + const UserComputation& to_apply_computation); // Enqueues a select-and-scatter instruction onto this user // computation. StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& scatter_to_selected_window_element_request, + const SelectAndScatterRequest& select_and_scatter_request, const UserComputation& select_computation, const UserComputation& scatter_computation); @@ -212,6 +220,12 @@ class UserComputation { const UserComputation& condition_computation, const UserComputation& body_computation); + // Enqueues a conditional instruction on this user computation. + StatusOr AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation); + // Enqueues a Send instruction onto this user computation. Status AddSendInstruction(const SendRequest& send_request); diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 5afaf226ae0cce7e9afc966c6b4adf838aeebc91..ca02115863e6906ef709ba63259024877e0dcef4 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -65,6 +65,7 @@ TEST_F(UserComputationTest, SimpleComputation) { OutfeedRequest outfeed_request; *outfeed_request.mutable_operand() = constant_handle; + *outfeed_request.mutable_shape() = kVectorShape; outfeed_request.set_outfeed_config("abc"); TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request)); @@ -334,50 +335,5 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } -TEST_F(UserComputationTest, SkipDotInEliminatingImplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // %a = Param({1, 3}); - // %b = Param({3, 1}); - // %dot = Dot(%a, %b); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {3, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest dot; - dot.set_binop(BINOP_DOT); - *dot.mutable_lhs() = a_handle; - *dot.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(dot).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - EXPECT_EQ(3, hlo_computation->instruction_count()); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 2fac914892e07b1935581e770293ddf00af7bc41..fb0e6f7ce00cff48727dc55bf45c07994643331d 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -289,7 +289,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Don't try this transformation if the while loop isn't removable, since if // it succeeds ultimately we're going to have to replace the old while loop // with a new one. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Can't remove dead parameters from non-removable while op."; return false; } @@ -306,6 +306,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + // Bail if param0 of while_cond or while_body has users which aren't of type // get-tuple-element. for (const HloInstruction* instr : {while_body->parameter_instruction(0), @@ -313,9 +315,10 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { for (const HloInstruction* user : instr->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { VLOG(2) << "Cowardly refusing to analyze while loop with " - << instr->ToStringNoMetadata() - << " used by non-GTE instruction " << user->ToStringNoMetadata() - << " in computation " << instr->parent()->name(); + << instr->ToString(print_no_metadata) + << " used by non-GTE instruction " + << user->ToString(print_no_metadata) << " in computation " + << instr->parent()->name(); return false; } } @@ -342,7 +345,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // // Careful: HloInstruction::operand_index returns the first index the // operand appears in, but it may appear more than once! - if (user->user_count() == 1 && user->users()[0] == while_body_root && + if (user->user_count() == 1 && user->users().front() == while_body_root && while_body_root->operand_index(user) == user->tuple_index() && std::count(while_body_root->operands().begin(), while_body_root->operands().end(), user) == 1) { @@ -351,7 +354,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.insert(user->tuple_index()); if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) << " uses all of its inputs; no simplification possible."; return false; } @@ -375,7 +378,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.insert(i); if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) << " uses all of its inputs; no simplification possible."; return false; } @@ -387,7 +390,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { CHECK_LT(used_tuple_indices.size(), tuple_size); VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() - << " elements from tuple of " << while_op->ToStringNoMetadata(); + << " elements from tuple of " + << while_op->ToString(print_no_metadata); // Build up maps from the old/new to the new/old tuple indices. std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), @@ -403,6 +407,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Compute the shape of the while op after we remove the dead indices. std::vector new_while_tuple_elem_shapes; + new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size()); for (int64 old_idx : new_to_old_tuple_idx) { new_while_tuple_elem_shapes.push_back( while_init->shape().tuple_shapes(old_idx)); @@ -430,7 +435,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { continue; } CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement) - << user->ToStringNoMetadata(); + << user->ToString(print_no_metadata); int64 old_idx = user->tuple_index(); auto new_idx_iter = old_to_new_tuple_idx.find(old_idx); @@ -443,15 +448,16 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // This is a GTE of an index that we've removed. Remove it from the // cloned computation. CHECK(user->user_count() == 0 || - user->user_count() == 1 && user->users()[0] == while_body_root) - << "Instruction " << user->ToStringNoMetadata() + user->user_count() == 1 && + user->users().front() == while_body_root) + << "Instruction " << user->ToString(print_no_metadata) << " should be unused (except by root of while body), but has " "users: {" << tensorflow::str_util::Join( user->users(), ", ", - [](string* out, const HloInstruction* instr) { + [&](string* out, const HloInstruction* instr) { tensorflow::strings::StrAppend( - out, instr->ToStringNoMetadata()); + out, instr->ToString(print_no_metadata)); }) << "}"; @@ -469,6 +475,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); std::vector new_while_body_root_elems; + new_while_body_root_elems.reserve(new_to_old_tuple_idx.size()); for (int64 old_idx : new_to_old_tuple_idx) { new_while_body_root_elems.push_back( while_body_root->mutable_operand(old_idx)); @@ -483,6 +490,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // clean this up in the common case where while_init is a tuple op. (It's // definitely tuple-shaped, but it's not necessarily a tuple op.) std::vector new_while_init_elems; + new_while_init_elems.reserve(new_to_old_tuple_idx.size()); for (int64 old_idx : new_to_old_tuple_idx) { new_while_init_elems.push_back( computation->AddInstruction(HloInstruction::CreateGetTupleElement( @@ -554,7 +562,7 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // the loop aren't removed, just cloned and added back to the loop. // Nevertheless our infrastructure sees loop simplification as removal of // these nodes and currently doesn't allow it. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Not attempting to remove while loop it is not removable: " << while_op->ToShortString(); return false; diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 5bf9842a6ce7be747f58c10f302f85c6f82ac6f9..789eba5780d37e1fd4d80ec881855951c8bba0eb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -32,13 +32,13 @@ tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { return tensorflow::Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const { - if (!ShapeUtil::Compatible(*other_shape, shape_)) { +tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { + if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*other_shape).c_str(), + ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } - *other_shape = shape_; + *to_shape = shape_; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 92564660f21bf1b596c4b9ca04c07eaca27ed192..4c83750f3e6f3c735db66d8e0b86ae3f43e5ca11 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -38,18 +38,19 @@ class ShapeLayout { explicit ShapeLayout(const Shape& shape) : shape_(shape) {} // Assigns the layouts in this ShapeLayout to the Layout fields of the given - // shape. 'shape' and the shape of the ShapeLayout object must be compatible. - tensorflow::Status AssignLayoutToShape(Shape* shape) const; + // shape. 'to_shape' and the shape of the ShapeLayout object must be + // compatible. + tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible // with the ShapeLayout's shape, then false is returned. bool MatchesLayoutInShape(const Shape& shape) const; - // Copies the layout from the given shape into this ShapeLayout. 'shape' must - // be compatible with the ShapeLayout's shape, and 'shape' must have a layout - // (LayoutUtil::HasLayout). - tensorflow::Status CopyLayoutFromShape(const Shape& shape); + // Copies the layout from the given shape into this ShapeLayout. 'other_shape' + // must be compatible with the ShapeLayout's shape, and 'other_shape' must + // have a layout (LayoutUtil::HasLayout). + tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index bf8d19015079f2ce0bd450594040ed818f94b66b..d752619bd65751779c24f061e44e206d66b01465 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -238,7 +238,7 @@ class ShapeTree { // (or compatible). // index : the index of the element in the shape. See ShapeUtil::GetSubshape // for definition of index. - // data : The data value at this elemnt. + // data : The data value at this element. template void ForEachElement(const Fn& func) const; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index c0a0e13f073a639baa46151a68b83cfe92215c23..ead9f5c4ce76a8d452dd18f5cd1803a027556637 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -64,30 +65,36 @@ namespace { // the shapes are the same. If compare_layouts is true, then layouts must also // match. bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs)) { - return ShapeUtil::IsTuple(rhs) && + if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { + return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { return CompareShapes(l, r, compare_layouts); }); + } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { + return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); } - // Explicitly compare the fields rather than using MessageDifferencer because - // we want empty layouts to be treated identically to missing layouts. + if (compare_layouts) { - if (!ContainersEqual(lhs.layout().minor_to_major(), - rhs.layout().minor_to_major())) { - VLOG(3) << "CompareShapes: lhs layout != rhs layout"; - return false; - } - if (!ContainersEqual(lhs.layout().padded_dimensions(), - rhs.layout().padded_dimensions())) { - VLOG(3) - << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; + if (lhs.layout().format() != rhs.layout().format()) { return false; } - if (lhs.layout().padding_value() != rhs.layout().padding_value()) { - VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; - return false; + if (LayoutUtil::IsDense(lhs)) { + if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), + LayoutUtil::MinorToMajor(rhs))) { + VLOG(3) << "CompareShapes: lhs layout != rhs layout"; + return false; + } + if (!ContainersEqual(lhs.layout().padded_dimensions(), + rhs.layout().padded_dimensions())) { + VLOG(3) + << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; + return false; + } + if (lhs.layout().padding_value() != rhs.layout().padding_value()) { + VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; + return false; + } } } @@ -235,6 +242,7 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { + CHECK(LayoutUtil::IsDense(*shape)); shape->mutable_layout()->add_minor_to_major(Rank(*shape)); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); @@ -329,6 +337,14 @@ StatusOr MakeShapeWithLayoutInternal( return MakeTupleShape(new_elements); } +// Returns the shape of a real or imaginary component. +/* static */ Shape ShapeUtil::ComplexComponentShape( + const Shape& complex_shape) { + CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape); + return ChangeElementType(complex_shape, primitive_util::ComplexComponentType( + complex_shape.element_type())); +} + /* static */ bool ShapeUtil::ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions) { @@ -396,6 +412,26 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); return gen->LowercaseName(s); } + +StatusOr StringToPrimitiveType(const string& name) { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + auto found = name_to_type->find(name); + if (found == name_to_type->end()) { + return InvalidArgument("Invalid element type string: \"%s\".", + name.c_str()); + } + return found->second; +} + } // namespace /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -500,17 +536,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { comma_list_to_int64s(dimensions_string)); // Extract the primitive element type. - PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID; - for (PrimitiveType i = - static_cast(PRIMITIVE_TYPE_INVALID + 1); - i < TUPLE; i = static_cast(i + 1)) { - if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) == - element_type_string) { - primitive_type = i; - break; - } - } - if (primitive_type == PRIMITIVE_TYPE_INVALID) { + TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, + StringToPrimitiveType(element_type_string)); + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || + primitive_type == OPAQUE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } @@ -553,6 +582,16 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); } +/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, + const Shape& rhs) { + if (lhs.element_type() == TUPLE) { + return rhs.element_type() == TUPLE && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringElementType); + } + return SameDimensions(lhs, rhs); +} + /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, int64 dimension_number) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); @@ -684,9 +723,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return LayoutUtil::ValidateLayoutInShape(shape); } -/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape, +/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = shape; + Shape new_shape = original; new_shape.set_element_type(type); return new_shape; } @@ -853,7 +892,9 @@ Status ForEachMutableSubshapeHelper( new_shape.add_dimensions(dim); } if (shape.has_layout()) { + CHECK(LayoutUtil::IsDense(shape)); Layout* new_layout = new_shape.mutable_layout(); + new_layout->set_format(DENSE); new_layout->clear_minor_to_major(); for (auto index : Permute(permutation, shape.layout().minor_to_major())) { new_layout->add_minor_to_major(index); @@ -1280,6 +1321,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); + layout->set_format(DENSE); for (size_t i = 0; i < layout->minor_to_major().size();) { if (layout->minor_to_major(i) == dim_to_delete) { layout->mutable_minor_to_major()->erase( diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 82a513a65ad62904e595b650cc02dcf3e8451958..301247d61c5e1ecd428b061594c042ab35a3364e 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -170,7 +171,7 @@ class ShapeUtil { // As above, but for program shapes, returns a string for the form: // // (param_name: f32[42x12], ...) -> f32[24x42] - static string HumanString(const ProgramShape& shape); + static string HumanString(const ProgramShape& program_shape); // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. @@ -190,6 +191,11 @@ class ShapeUtil { // compatibility. static bool Compatible(const Shape& lhs, const Shape& rhs); + // Returns true if the rank and dimension sizes are identical. Element type + // and layout are ignored. Tuple elements are compared recursively for + // compatibility. + static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); + // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); @@ -319,7 +325,8 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } - // Returns whether the shape is an array. + // Returns whether the shape is an array. Note that scalars are considered + // arrays. static bool IsArray(const Shape& shape) { return !IsTuple(shape) && !IsOpaque(shape); } @@ -346,6 +353,10 @@ class ShapeUtil { // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); + // Returns the shape of the real/imaginary components of the given complex + // shape. + static Shape ComplexComponentShape(const Shape& complex_shape); + // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. // @@ -497,8 +508,7 @@ class ShapeUtil { CHECK_EQ(Rank(shape), base.size()); CHECK_EQ(incr.size(), base.size()); CHECK_EQ(count.size(), base.size()); - const Layout& layout = shape.layout(); - const int64 rank = layout.minor_to_major_size(); + const int64 rank = LayoutUtil::MinorToMajor(shape).size(); // Allows handling R0 arrays, such that the visitor function will be called // once with the proper empty indexes. int64 n = -1; @@ -506,7 +516,7 @@ class ShapeUtil { while (n < rank && visitor_function(indexes)) { // Increments dimensions in minor to major order. for (n = 0; n < rank; ++n) { - int64 dim = layout.minor_to_major(n); + int64 dim = LayoutUtil::Minor(shape.layout(), n); indexes[dim] += incr[dim]; if (indexes[dim] < base[dim] + count[dim]) { break; diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 0ba542ad1bec290c35c52a8dd5177893770310fd..3be6d6c4299aff62582c1b9fdc46fb78712f95c8 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -145,6 +145,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { Shape tuple2 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { @@ -153,6 +154,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { Shape tuple2 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})}); EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); + EXPECT_TRUE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { @@ -163,20 +165,6 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); } -TEST(ShapeUtilTest, EmptyLayoutEqualsMissingLayout) { - // A shape with a missing layout should be equal to a shape with an empty - // layout. - Shape scalar1 = ShapeUtil::MakeShape(F32, {}); - Shape scalar2 = ShapeUtil::MakeShape(F32, {}); - - EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2)); - - scalar1.clear_layout(); // Remove layout field. - scalar2.mutable_layout(); // Create empty layout field. - - EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2)); -} - TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); shape1.mutable_layout()->add_padded_dimensions(10); @@ -197,17 +185,17 @@ TEST(ShapeUtilTest, CompareShapesWithPaddingValueMismatch) { EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); } -TEST(ShapeUtilTest, ScalarUnpopulatedLayoutEqualsScalarLayout) { - Shape scalar_unpopulated = ShapeUtil::MakeShape(F32, {}); - scalar_unpopulated.clear_layout(); - ASSERT_FALSE(scalar_unpopulated.has_layout()) - << ShapeUtil::HumanStringWithLayout(scalar_unpopulated); +TEST(ShapeUtilTest, ScalarDefaultLayoutEqualsScalarEmptyMin2Maj) { + Shape scalar_default_layout = ShapeUtil::MakeShape(F32, {}); + ASSERT_TRUE(scalar_default_layout.has_layout()) + << ShapeUtil::HumanStringWithLayout(scalar_default_layout); - const Shape scalar_populated = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); - ASSERT_TRUE(scalar_populated.has_layout()) - << ShapeUtil::HumanStringWithLayout(scalar_populated); + const Shape scalar_empty_min2maj = + ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + ASSERT_TRUE(scalar_empty_min2maj.has_layout()) + << ShapeUtil::HumanStringWithLayout(scalar_empty_min2maj); - EXPECT_TRUE(ShapeUtil::Equal(scalar_unpopulated, scalar_populated)); + EXPECT_TRUE(ShapeUtil::Equal(scalar_default_layout, scalar_empty_min2maj)); } TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index 5fa2211ac66177514ac8ecabfa8791e7c8c014a2..f9d25945bc617507735fb6c4d011c39723497f69 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -32,26 +32,26 @@ namespace { class Base1 { public: virtual ~Base1() {} - int pad; + int pad_; }; class Base2 { public: virtual ~Base2() {} - int yetotherpad; + int yetotherpad_; }; class Derived : public Base1, public Base2 { public: ~Derived() override {} - int evenmorepad; + int evenmorepad_; }; class CopyNoAssign { public: - explicit CopyNoAssign(int value) : foo(value) {} - CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {} - int foo; + explicit CopyNoAssign(int value) : foo_(value) {} + CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} + int foo_; private: const CopyNoAssign& operator=(const CopyNoAssign&); @@ -253,7 +253,7 @@ TEST(StatusOr, TestCopyCtorNonAssignable) { StatusOr original(value); StatusOr copy(original); EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo); + EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_); } TEST(StatusOr, TestCopyCtorStatusOKConverting) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index f3885e90214e8ea77d26e5ae250fc5821267826b..d8c0584d10c854ff46c6ce65c37a8ec92e02d6cf 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", @@ -104,7 +105,9 @@ cc_library( hdrs = ["hlo_test_base.h"], deps = [ ":literal_test_util", + ":test_utils", "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -114,6 +117,10 @@ cc_library( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -354,6 +361,7 @@ xla_test( xla_test( name = "map_test", srcs = ["map_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -431,6 +439,28 @@ xla_test( ], ) +xla_test( + name = "conditional_test", + srcs = ["conditional_test.cc"], + # Currently, Conditional is supported only in CPU and GPU backends. + backends = [ + "cpu", + "gpu", + "cpu_parallel", + ], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], @@ -512,6 +542,7 @@ xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -770,6 +801,41 @@ xla_test( ], ) +xla_test( + name = "bfloat16_test", + srcs = ["bfloat16_test.cc"], + blacklisted_backends = [ + "gpu", + ], + shard_count = 40, + deps = [ + ":test_utils", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "slice_test", srcs = ["slice_test.cc"], @@ -1230,6 +1296,23 @@ xla_test( ], ) +xla_test( + name = "bitcast_convert_test", + srcs = ["bitcast_convert_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + xla_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.cc"], @@ -1294,6 +1377,7 @@ xla_test( srcs = ["client_test.cc"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", @@ -1621,6 +1705,45 @@ xla_test( ], ) +# A demo of textual IR based test. +xla_test( + name = "sample_text_test", + srcs = ["sample_text_test.cc"], + # You can leave this empty if you want to test all supported backends. + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +# A demo of test that loads an hlo module from a file and compares results on gpu and cpu. +tf_cc_test( + name = "sample_file_test", + srcs = ["sample_file_test.cc"], + data = ["isolated_convolution.hlo"], + tags = ["requires-gpu-sm35"], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:cpu_plugin", # reference backend + "//tensorflow/compiler/xla/service:gpu_plugin", # test backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 028d1251b455b82a291c236f7866e52e27d3590e..7525bc4bdfbaa942ea8af29af31829ae8742e833 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -203,6 +204,15 @@ struct BatchNormTestParam { int64 feature_index; float random_value_mean; float random_value_var; + + friend ::std::ostream& operator<<(::std::ostream& os, + const BatchNormTestParam& p) { + os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; + os << "feature_index=" << p.feature_index << ", "; + os << "random_value_mean=" << p.random_value_mean << ", "; + os << "random_value_var=" << p.random_value_var; + return os; + } }; // Tests to test the fused operation of BatchNorm. diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac3f3f4c9ddb03d003a44f5abd7a2e26c42f490d --- /dev/null +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class Bfloat16Test : public ClientLibraryTestBase { + protected: + const ErrorSpec error_spec_{0.001, 0.001}; +}; + +XLA_TEST_F(Bfloat16Test, ScalarOperation) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(static_cast(2.0f)); + auto y = builder.ConstantR0(static_cast(1.0f)); + builder.Add(x, y); + + ComputeAndCompareR0(&builder, static_cast(3.0f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, LogOperation) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(static_cast(4.0f)); + builder.Log(x); + + ComputeAndCompareR0(&builder, static_cast(1.387f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, NegateScalarF16) { + ComputationBuilder builder(client_, TestName()); + builder.Neg(builder.ConstantR0(static_cast(2.1f))); + + ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, BatchNormTraining) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{static_cast(1.f)}, {static_cast(2.f)}}, + {{static_cast(3.f)}, {static_cast(4.f)}}}, + {{{static_cast(5.f)}, {static_cast(6.f)}}, + {{static_cast(7.f)}, {static_cast(8.f)}}}}); + + auto scale = builder.ConstantR1( + {static_cast(2.0f), static_cast(3.0f)}); + + auto offset = builder.ConstantR1( + {static_cast(1.0f), static_cast(2.0f)}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4( + {{{{static_cast(-1.7f)}, {static_cast(-2.04f)}}, + {{static_cast(0.105f)}, {static_cast(0.65f)}}}, + {{{static_cast(1.89f)}, {static_cast(3.35f)}}, + {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) + .get(), + Literal::CreateR1( + {static_cast(4), static_cast(5)}) + .get(), + Literal::CreateR1( + {static_cast(5), static_cast(5)}) + .get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); +} + +XLA_TEST_F(Bfloat16Test, BatchNormGrad) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + Array4D(2, 2, 2, 1, static_cast(0.0f))); + + auto scale = builder.ConstantR1( + {static_cast(1.0f), static_cast(1.0f)}); + + auto mean = builder.ConstantR1( + {static_cast(0.0f), static_cast(0.0f)}); + + auto var = builder.ConstantR1( + {static_cast(1.0f), static_cast(1.0f)}); + + auto grad_output = builder.ConstantR4FromArray4D( + {{{{static_cast(1.f)}, {static_cast(2.f)}}, + {{static_cast(3.f)}, {static_cast(4.f)}}}, + {{{static_cast(5.f)}, {static_cast(6.f)}}, + {{static_cast(7.f)}, {static_cast(8.f)}}}}); + + builder.BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4( + {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, + {{static_cast(-1.f)}, {static_cast(-1.f)}}}, + {{{static_cast(1.f)}, {static_cast(1.f)}}, + {{static_cast(3.f)}, {static_cast(3.f)}}}}) + .get(), + Literal::CreateR1( + {static_cast(0), static_cast(0)}) + .get(), + Literal::CreateR1( + {static_cast(16), static_cast(20)}) + .get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0d94d65c1015fb54ada3fdfc95d0c31d0a0f158b --- /dev/null +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class BitcastConvertTest : public ClientLibraryTestBase { + public: + explicit BitcastConvertTest(perftools::gputools::Platform* platform = nullptr) + : ClientLibraryTestBase(platform) { + mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); + mutable_debug_options()->add_xla_disable_hlo_passes("inline"); + } +}; + +TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42, 64}); + builder.BitcastConvertType(a, S32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0f, 64.0f}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({0, static_cast(0x80000000), 0x3F800000, + static_cast(0xBF800000), 0x3F000000, + static_cast(0xBF000000)}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.6, 64.4}); + builder.BitcastConvertType(a, S32); + + std::vector expected = {0x422a6666, 0x4280cccd}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertS32Extremes) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {std::numeric_limits::min(), std::numeric_limits::max()}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {-0.0f, NAN}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0, 0)); +} + +TEST_F(BitcastConvertTest, ConvertMapToS32) { + ComputationBuilder builder(client_, TestName()); + auto b = builder.CreateSubBuilder("convert"); + auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); + b->BitcastConvertType(param, S32); + auto a = builder.ConstantR1({42.0f, 64.0f}); + builder.Map({a}, b->BuildAndNoteError(), {0}); + + std::vector expected = {0x42280000, 0x42800000}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertMapToF32) { + ComputationBuilder builder(client_, TestName()); + auto b = builder.CreateSubBuilder("convert"); + auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); + b->BitcastConvertType(param, F32); + auto a = builder.ConstantR1({0x42280000, 0x42800000}); + builder.Map({a}, b->BuildAndNoteError(), {0}); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +// Regression test for b/31758660. When ReshapeMover transforms +// input -> reshape -> convert +// to +// input -> convert -> reshape +// the new convert should have the same element type as the old convert. +TEST_F(BitcastConvertTest, ConvertReshape) { + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR1({0x42280000}); + auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + builder.BitcastConvertType(reshape, F32); + + ComputeAndCompareR0(&builder, 42.0f, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index f594c609db6282513a27a479a85e6a3dd1a7a3cd..610302ac1256a57db6ed6e18016a4136973e3891 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -29,6 +29,7 @@ def xla_test(name, deps, xla_test_library_deps=[], backends=[], + blacklisted_backends=[], args=[], tags=[], copts=[], @@ -92,17 +93,24 @@ def xla_test(name, backends: A list of backends to generate tests for. Supported values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will be generated for all supported backends. + blacklisted_backends: A list of backends to NOT generate tests for. args: Test arguments for the target. tags: Tags for the target. - backend_args: A dict mapping backend name to list of additional args to - use for that target. + copts: Additional copts to pass to the build. + data: Additional data to pass to the build. backend_tags: A dict mapping backend name to list of additional tags to use for that target. + backend_args: A dict mapping backend name to list of additional args to + use for that target. + **kwargs: Additional keyword arguments to pass to native.cc_test. """ test_names = [] if not backends: backends = all_backends + backends = [backend for backend in backends + if backend not in blacklisted_backends] + native.cc_library( name="%s_lib" % name, srcs=srcs, diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index ef54714e46ffe6f22f26410c33fa62c2d528f280..50bf185936808fbd9c49f7fbd5ab0c0b4a76504b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -262,20 +262,39 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( expected.shape().element_type() == PRED) << ShapeUtil::HumanString(expected.shape()); } + // We allow using a float expected literal for a bfloat16 output. In this + // case, we need to convert the expected literal to bfloat16. + const Literal* expected_ptr = &expected; + std::unique_ptr converted_expected; + Shape layout_shape; + if (use_bfloat16_) { + converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + expected_ptr = converted_expected.get(); + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + shape_with_layout = &layout_shape; + } + } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(expected, actual, error_message); + LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( - computation, expected, arguments, expect_equal); + computation, *expected_ptr, arguments, expect_equal); } if (execution_options_.debug_options().xla_test_all_input_layouts()) { return ComputeAndCompareLiteralWithAllInputLayouts( - computation, expected, arguments, expect_equal, shape_with_layout); + computation, *expected_ptr, arguments, expect_equal, shape_with_layout); } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(expected, *actual); + LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); return tensorflow::Status::OK(); } @@ -286,20 +305,39 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + // We allow using a float expected literal for a bfloat16 output. In this + // case, we need to convert the expected literal to bfloat16. + const Literal* expected_ptr = &expected; + std::unique_ptr converted_expected; + Shape layout_shape; + if (use_bfloat16_) { + converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + expected_ptr = converted_expected.get(); + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + shape_with_layout = &layout_shape; + } + } auto expect_near = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectNear(expected, actual, error, error_message); + LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { - return ComputeAndCompareLiteralWithAllOutputLayouts(computation, expected, - arguments, expect_near); + return ComputeAndCompareLiteralWithAllOutputLayouts( + computation, *expected_ptr, arguments, expect_near); } if (execution_options_.debug_options().xla_test_all_input_layouts()) { return ComputeAndCompareLiteralWithAllInputLayouts( - computation, expected, arguments, expect_near, shape_with_layout); + computation, *expected_ptr, arguments, expect_near, shape_with_layout); } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(expected, *actual, error); + LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); return tensorflow::Status::OK(); } @@ -402,8 +440,11 @@ ClientLibraryTestBase::ComputeValueAndReference( Computation ClientLibraryTestBase::CreateScalarRelu() { ComputationBuilder builder(client_, "relu"); - auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); - auto zero = builder.ConstantR0(0.0); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto z_value = builder.Parameter(0, shape, "z_value"); + auto zero = use_bfloat16_ + ? builder.ConstantR0(static_cast(0.0f)) + : builder.ConstantR0(0.0f); builder.Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -412,8 +453,9 @@ Computation ClientLibraryTestBase::CreateScalarRelu() { Computation ClientLibraryTestBase::CreateScalarMax() { ComputationBuilder builder(client_, "max"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto x = builder.Parameter(0, shape, "x"); + auto y = builder.Parameter(1, shape, "y"); builder.Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -422,11 +464,12 @@ Computation ClientLibraryTestBase::CreateScalarMax() { Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { ComputationBuilder builder(client_, "relu_sensitivity"); - auto activation = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "activation"); - auto backprop = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "backprop"); - auto zero = builder.ConstantR0(0.0); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto activation = builder.Parameter(0, shape, "activation"); + auto backprop = builder.Parameter(1, shape, "backprop"); + auto zero = use_bfloat16_ + ? builder.ConstantR0(static_cast(0.0f)) + : builder.ConstantR0(0.0f); auto activation_gtz = builder.Gt(activation, zero); builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); @@ -461,4 +504,27 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + +ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( + const Literal& literal, ComputationBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 1dc274c59172313bcc1b6e5e7029657c3fea937f..4d0cf8bf71cf22d7c046bb22754a8d4e299ed9db 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -194,7 +194,7 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); void ComputeAndCompareTuple( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec abs_error); + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Convenience method for running a built computation and comparing the result // with the HloEvaluator. @@ -245,51 +245,102 @@ class ClientLibraryTestBase : public ::testing::Test { const int rows, const int cols, const int rows_padded, const int cols_padded); - // Create a parameter instruction that wraps a given value and then stores + // Creates a parameter instruction, transfers the literal for the parameter to + // server, then stores into "data_handle" the global handle for that + // parameter. When the use_bfloat16 flag is set but the literal has F32 + // elements, the literal will be converted to BF16 before being transferred. + std::unique_ptr CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle); + + // Creates a constant instruction with the given literal. When the + // use_bfloat16 flag is set but the literal has F32 elements, the elements + // will be converted to BF16s. + ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, + ComputationBuilder* builder); + + // Creates a constant instruction with the given array. When the use_bfloat16 + // flag is set but the array has float elements, the elements will be + // converted to bfloat16s. + template + ComputationDataHandle CreateConstantFromArray(const Array& array, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + } + + // Same as CreateConstantFromArray, but for scalars. + template + ComputationDataHandle CreateConstantFromScalar(NativeT value, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateR0(value), + builder); + } + + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given values and then stores + // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given constant array + // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that // parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given constant array + // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that // parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); + // Getter and setter for the use_bfloat16 flag, which indicates whether to run + // tests with all float-type input/output converted to bfloat16. + bool use_bfloat16() const { return use_bfloat16_; } + void set_use_bfloat16(bool value) { use_bfloat16_ = value; } + + // The float type used in this test, BF16 or F32 according to use_bfloat16. + PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + Client* client_; ExecutionOptions execution_options_; @@ -315,6 +366,10 @@ class ClientLibraryTestBase : public ::testing::Test { ComputeValueAndReference(ComputationBuilder* builder, const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice arguments); + + // Whether to run tests with all float-type input/output converted to + // bfloat16. + bool use_bfloat16_ = false; }; template @@ -333,6 +388,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -357,6 +413,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -381,6 +438,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -405,6 +463,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -429,6 +488,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -442,6 +502,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -454,6 +517,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -466,6 +532,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -478,6 +547,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 183bcf1dd333a6955bcae6dd07d2ef31fe817434..8853ed9e5780672d4006c326291767b8b5253f56 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -20,10 +20,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -42,26 +44,26 @@ TEST_F(ClientTest, ExecuteWithLayout) { for (const std::vector& transfer_layout : layouts) { b.Add(b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})); - auto computation = b.Build(); - ASSERT_TRUE(computation.ok()) << computation.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, execute_layout); - std::unique_ptr data = - client_->Execute(computation.ValueOrDie(), {}, &execution_options) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr data, + client_->Execute(computation, {}, &execution_options)); std::unique_ptr expected_literal = Literal::CreateR2WithLayout( {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); - auto computed = client_->Transfer(*data, &expected_literal->shape()); + TF_ASSERT_OK_AND_ASSIGN( + auto computed, client_->Transfer(*data, &expected_literal->shape())); - LiteralTestUtil::AssertEqualShapesAndLayouts( - expected_literal->shape(), computed.ValueOrDie()->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), + computed->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed); } } } @@ -72,8 +74,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})}); - auto computation = b.Build(); - ASSERT_TRUE(computation.ok()) << computation.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; // Create a result shape with one element column major and the other row @@ -85,10 +86,9 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0})}); - auto result = - client_ - ->ExecuteAndTransfer(computation.ValueOrDie(), {}, &execution_options) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, result->tuple_literals(0)); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, @@ -107,5 +107,42 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } +TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { + Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; + Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr const_arg, + client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); + + ComputationBuilder b(client_, TestName() + ".add"); + b.Add(b.Parameter(0, shape, "param_0"), + b.ConstantR2({{1, 2}, {3, 4}})); + TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); + + // We can't really test parallel execution on CPU since all of the cores in a + // CPU are presented as a single device. So for now we test "parallel" + // execution on a single device. + std::vector computation_instances; + TF_ASSERT_OK_AND_ASSIGN(std::vector devices, + client_->GetDeviceHandles(1)); + ASSERT_EQ(devices.size(), 1); + + ExecutionOptions options = execution_options_; + *options.add_device_handles() = devices[0]; + computation_instances.push_back(Client::ComputationInstance( + add_with_one_arg, {const_arg.get()}, options, nullptr)); + + TF_ASSERT_OK_AND_ASSIGN(auto results, + client_->ExecuteParallel(computation_instances)); + auto expected_result = Literal::CreateR2({{6, 8}, {10, 12}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto result_literal, + client_->Transfer(*results[0], &expected_result->shape())); + + LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index 43ea7f6019415a171123ee0315533b8a3b1ff984..e472408dcf7ed5fec74e886fd0092ce47ee2e7eb 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -19,8 +19,11 @@ namespace xla { StatusOr> CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { - return backend().compiler()->Compile(std::move(hlo_module), - backend().default_stream_executor()); + TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses( + std::move(hlo_module), + backend().default_stream_executor())); + return backend().compiler()->RunBackend(std::move(hlo_module), + backend().default_stream_executor()); } StatusOr> diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8c4932be821e410e25c41741df436544ab876f0 --- /dev/null +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -0,0 +1,325 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class ConditionalOpTest : public ClientLibraryTestBase { + protected: + Computation CreateR0F32ConstantComputation(float value) { + ComputationBuilder builder(client_, "Constant"); + builder.Parameter(0, empty_tuple_, "tuple"); + builder.ConstantR0(value); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32IdentityComputation() { + ComputationBuilder builder(client_, "Identity"); + builder.Parameter(0, r0f32_, "x"); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32CeilComputation() { + ComputationBuilder builder(client_, "Ceil"); + auto param = builder.Parameter(0, r0f32_, "param"); + builder.Ceil(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32FloorComputation() { + ComputationBuilder builder(client_, "Ceil"); + auto param = builder.Parameter(0, r0f32_, "param"); + builder.Floor(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateAddTupleComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Add(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateAddR0Computation() { + return CreateAddTupleComputation("AddR0", tuple_2_r0f32_); + } + + Computation CreateAddR1Computation() { + return CreateAddTupleComputation("AddR1", tuple_2_r1s2f32_); + } + + Computation CreateSubTupleComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Sub(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateSubR0Computation() { + return CreateSubTupleComputation("SubR0", tuple_2_r0f32_); + } + + Computation CreateSubR1Computation() { + return CreateSubTupleComputation("SubR1", tuple_2_r1s2f32_); + } + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); + Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})}); + Shape empty_tuple_ = ShapeUtil::MakeTupleShape({}); + ErrorSpec error_spec_{0.001}; +}; + +// Test true and false computations that do not take any parameters. +XLA_TEST_F(ConditionalOpTest, Parameters0) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({}); + auto true_computation = CreateR0F32ConstantComputation(56.0f); + auto false_computation = CreateR0F32ConstantComputation(12.0f); + auto result = builder.Conditional(pred, operands, true_computation, operands, + false_computation); + + ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); +} + +// Test true and false computations that take in 1 parameter. +XLA_TEST_F(ConditionalOpTest, Parameters1) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto identity = CreateR0F32IdentityComputation(); + auto result = + builder.Conditional(pred, operand1, identity, operand2, identity); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with two different computations in the true and false cases +// that take in different arguments. +XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto result = + builder.Conditional(pred, operand1, CreateR0F32CeilComputation(), + operand2, CreateR0F32FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with two different computations in the true and false cases +// that take in the same arguments. +XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand = builder.ConstantR0(12.6f); + auto result = builder.Conditional(pred, operand, CreateR0F32CeilComputation(), + operand, CreateR0F32FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with the same computation in the true and false cases but +// take in different arguments. +XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto floor = CreateR0F32FloorComputation(); + auto result = builder.Conditional(pred, operand1, floor, operand2, floor); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with the same computation in the true and false cases that +// take in the same arguments. +XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand = builder.ConstantR0(12.6f); + auto floor = CreateR0F32FloorComputation(); + auto result = builder.Conditional(pred, operand, floor, operand, floor); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with different instances of the same computation in the true +// and false cases. +XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto result = + builder.Conditional(pred, operand1, CreateR0F32FloorComputation(), + operand2, CreateR0F32FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test the case when a call invokes a computation that contains a conditional. +XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); + auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); + auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0F32CeilComputation(), false_operand, + CreateR0F32FloorComputation()); + auto inner_builder_result = inner_builder.Build(); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + builder.Call(inner_builder_result.ConsumeValueOrDie(), + {pred, operand1, operand2}); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// true. +XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR0Computation(), + operands, CreateSubR0Computation()); + + ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// false. +XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR0Computation(), + operands, CreateSubR0Computation()); + + ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is true. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR1Computation(), + operands, CreateSubR1Computation()); + + ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is false. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR1Computation(), + operands, CreateSubR1Computation()); + + ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); +} + +// Test the case where one conditional is nested within another. +XLA_TEST_F(ConditionalOpTest, NestedConditionals) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); + auto pred_cond = inner_builder.GetTupleElement(param0, 0); + auto true_operand = inner_builder.GetTupleElement(param0, 1); + auto false_operand = inner_builder.GetTupleElement(param0, 2); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0F32CeilComputation(), false_operand, + CreateR0F32FloorComputation()); + auto inner_builder_result = inner_builder.Build(); + + ComputationBuilder builder(client_, TestName()); + auto pred1 = builder.ConstantR0(true); + auto pred2 = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(1.1f); + auto operand2 = builder.ConstantR0(12.2f); + auto operand3 = builder.ConstantR0(43.3f); + auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); + builder.Conditional(pred1, tuple_operand, + inner_builder_result.ConsumeValueOrDie(), operand3, + CreateR0F32IdentityComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test a mismatch in the shape of the true operand and true computation. +XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + builder.Conditional(pred, operands, CreateAddR1Computation(), operands, + CreateSubR0Computation()); + + auto result = builder.Build(); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("true_operand must match the shape of the " + "only parameter of true_computation")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index b0a63bccbb93f226175beff2e30e2a243fdca1d3..896b34fb6e2762c14bd9ec2bf1ba13c548d4cf60 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -39,8 +39,8 @@ class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 0, 2, 2, 3, 0, 1, 2, - 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, + 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -49,13 +49,23 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 0, 1, 2, 3, 2, 3, 2, - 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, + 2, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); } +// Tests the convolution operation with invalid output dimension numbers. +TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { + auto dimension_numbers_status = + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, + 1, 2, 3); + ASSERT_FALSE(dimension_numbers_status.ok()); + ASSERT_THAT(dimension_numbers_status.status().error_message(), + ::testing::HasSubstr("output are not unique")); +} + XLA_TEST_F(ConvolutionDimensionNumbersTest, TwoConvsWithDifferentDimensionNumbers) { auto input_array = MakeUnique>(2, 3, 5, 5); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 7425f778a635c3b52b046d18ff79176a9c26c577..2924c08615fa706bb19addf04bf58e1d5dd5a659 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -370,9 +370,12 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_input_feature_dimension(4); dnums.set_output_feature_dimension(4); dnums.add_kernel_spatial_dimensions(0); @@ -423,8 +426,10 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); @@ -458,6 +463,54 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { error_spec_); } +// Test fixture to run convolution tests with and without convolution +// canonicalization enabled. +class ConvolveWithAndWithoutCanonicalization + : public ConvolutionTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, + DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) { + if (GetParam()) { + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "convolution-canonicalization"); + } + ComputationBuilder builder(client_, TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); + + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_feature_dimension(0); + dnums.set_input_batch_dimension(1); + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + auto conv = builder.ConvWithGeneralDimensions(input, filter, {}, + Padding::kValid, dnums); + + Array2D param0(4, 29); + param0.FillUnique(); + + Array2D param1(4, 10); + param1.FillUnique(); + + Array2D expected_result(29, 10); + expected_result.Fill(0); + + ComputeAndCompare( + &builder, conv, + {*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)}, + error_spec_); +} + +INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation, + ConvolveWithAndWithoutCanonicalization, + ::testing::Values(true, false)); + struct Convolve1DTestParam { int64 input_feature; int64 output_feature; @@ -490,7 +543,8 @@ XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); dnums.set_input_feature_dimension(2); dnums.set_output_feature_dimension(2); dnums.add_kernel_spatial_dimensions(0); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 9b36e3722b8f8a5d01c426425fdfb0c4b9ae3a16..9c1145def8c11f1222c63adf006102887d49f00d 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -320,9 +320,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); - const Array4D filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0 - 0, 100, 0, // row 1 - 10, 0, 1}); // row 2 + const Array4D filter_array(1, 1, 3, 3, + {10000, 0, 1000, // row 0 + 0, 100, 0, // row 1 + 10, 0, 1}); // row 2 auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kSame); @@ -472,7 +473,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { builder.Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { - 23, 33, 43, + 23, + 33, + 43, }; Array4D expected(bs, 1, 1, 1, expected_data); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -669,10 +672,11 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 1, 3, 4, input_data); - Array4D filter_array(1, 1, 4, 3, {100, 10, 1, // - 200, 20, 2, // - 300, 30, 3, // - 400, 40, 4}); + Array4D filter_array(1, 1, 4, 3, + {100, 10, 1, // + 200, 20, 2, // + 300, 30, 3, // + 400, 40, 4}); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.ConvGeneralDilated( @@ -681,9 +685,10 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { /*rhs_dilation=*/{}, ComputationBuilder::CreateDefaultConvDimensionNumbers()); - Array4D expected(1, 1, 3, 5, {204, 40, 406, 60, 608, // - 1518, 180, 1821, 210, 2124, // - 4146, 460, 4651, 510, 5156}); + Array4D expected(1, 1, 3, 5, + {204, 40, 406, 60, 608, // + 1518, 180, 1821, 210, 2124, // + 4146, 460, 4651, 510, 5156}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -926,7 +931,8 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { ComputeAndCompareR4(&builder, *expected, {}, error_spec_); } -XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) { +XLA_TEST_F(ConvolutionVariantsTest, + RandomData_Input16x16x16x16_Filter16x16x16x16) { constexpr int bs = 16; constexpr int iz = 16; constexpr int oz = 16; @@ -976,8 +982,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1018,8 +1026,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1060,8 +1070,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1099,8 +1111,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1131,7 +1145,8 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // Conv([1,2,3], Reverse([5,6]), padding_low=1) // into // BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1) -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardInputLowPaddingLessThanHighPadding) { ComputationBuilder builder(client_, TestName()); auto gradients = builder.ConstantR4FromArray4D( @@ -1149,7 +1164,8 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) // Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3) // into // BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1)) -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardInputLowPaddingGreaterThanHighPadding) { ComputationBuilder builder(client_, TestName()); auto gradients = builder.ConstantR4FromArray4D( @@ -1206,7 +1222,8 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_); } -XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardFilterLowPaddingLessThanHighPadding) { ComputationBuilder builder(client_, TestName()); // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0 @@ -1230,7 +1247,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) } XLA_TEST_F(ConvolutionVariantsTest, - BackwardFilterLowPaddingGreaterThanHighPadding) { + BackwardFilterLowPaddingGreaterThanHighPadding) { ComputationBuilder builder(client_, TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4 diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index bcb85b04eefa349df1c055e010d584b85b55a4a8..d64bf0aa5bd5e9d6213ea07b3da3305a9c621c65 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -56,9 +56,13 @@ class CopyOpTest : public HloTestBase { tensorflow::gtl::ArraySlice permutation); }; -XLA_TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0(true)); } +XLA_TEST_F(CopyOpTest, CopyR0Bool) { + TestCopyOp(*Literal::CreateR0(true)); +} -XLA_TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1({})); } +XLA_TEST_F(CopyOpTest, CopyR1S0U32) { + TestCopyOp(*Literal::CreateR1({})); +} XLA_TEST_F(CopyOpTest, CopyR1S3U32) { TestCopyOp(*Literal::CreateR1({1, 2, 3})); @@ -85,7 +89,6 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = Literal::CreateR0(42.0); Shape shape = literal->shape(); - auto constant_device_base = TransferToDevice(*literal); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -98,7 +101,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {constant_device_base}); + ExecuteAndTransfer(std::move(module), {literal.get()}); LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index bfb04fd9f9bf6887c4462cb00fee00250517f5c4..cc683701e6305510d202721fe645310f1009081c 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -51,8 +51,6 @@ class DotOperationTest : public ClientLibraryTestBase { template void TestNonsquareMatrixDot(bool lhs_row_major = false, bool rhs_row_major = false); - void TestMatrixDot(int M, int K, int N, bool lhs_row_major = false, - bool rhs_row_major = false); }; XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { @@ -199,158 +197,182 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major, - bool rhs_row_major) { - std::unique_ptr> lhs_data = - MakeLinspaceArray2D(0.0, 1.0, M, K); - std::unique_ptr lhs_lit = Literal::CreateR2FromArray2DWithLayout( - *lhs_data, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))); - auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie(); +struct DotTestParam { + int m; + int k; + int n; + bool dot_lhs_row_major; + bool dot_rhs_row_major; + bool has_addend; + bool addend_row_major; +}; - std::unique_ptr> rhs_data = - MakeLinspaceArray2D(0.0, 1.0, K, N); - std::unique_ptr rhs_lit = Literal::CreateR2FromArray2DWithLayout( - *rhs_data, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))); - auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie(); +string PrintDotTestParam( + const ::testing::TestParamInfo& test_param) { + const DotTestParam& param = test_param.param; + if (param.has_addend) { + return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, + "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F", + param.addend_row_major ? "T" : "F"); + } else { + return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, + "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F"); + } +} + +class ParametricDotTest : public DotOperationTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ParametricDotTest, TestF32) { + DotTestParam param = GetParam(); + + std::unique_ptr> dot_lhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); + std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); + std::unique_ptr dot_lhs_handle = + client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + + std::unique_ptr> dot_rhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); + std::unique_ptr dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( + *dot_rhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_rhs_row_major))); + std::unique_ptr dot_rhs_handle = + client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + + std::unique_ptr> addend_data; + std::unique_ptr addend_lit; + std::unique_ptr addend_handle; + + if (param.has_addend) { + addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); + addend_lit = Literal::CreateR2FromArray2DWithLayout( + *addend_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.addend_row_major))); + addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + } ComputationBuilder builder(client_, TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {M, K}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {K, N}), "rhs")); - - std::unique_ptr> expected = - ReferenceUtil::MatmulArray2D(*lhs_data, *rhs_data); - - ComputeAndCompareR2(&builder, *expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(0.3, 3e-3)); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTF) { - TestMatrixDot(12, 117, 7, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFT) { - TestMatrixDot(12, 117, 7, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTT) { - TestMatrixDot(12, 117, 7, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFF) { - TestMatrixDot(12, 117, 7, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTT) { - TestMatrixDot(270, 270, 520, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTF) { - TestMatrixDot(270, 270, 520, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFT) { - TestMatrixDot(270, 270, 520, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFF) { - TestMatrixDot(270, 270, 520, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTT) { - TestMatrixDot(269, 3, 520, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTF) { - TestMatrixDot(260, 3, 520, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFT) { - TestMatrixDot(260, 3, 520, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) { - TestMatrixDot(260, 3, 520, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) { - TestMatrixDot(1, 8, 8, true, true); -} + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}), + "dot_lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}), + "dot_rhs")); + + if (param.has_addend) { + result = builder.Add( + result, + builder.Parameter( + 2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend")); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) { - TestMatrixDot(1, 130, 8, true, true); -} + std::unique_ptr> expected; + if (param.has_addend) { + expected = ReferenceUtil::ApplyElementwise2D( + std::plus(), + *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), + *addend_data); + } else { + expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) { - TestMatrixDot(1, 8, 130, true, true); -} + std::vector args = {dot_lhs_handle.get(), dot_rhs_handle.get()}; + if (param.has_addend) { + args.push_back(addend_handle.get()); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) { - TestMatrixDot(1, 290, 130, true, true); + ComputeAndCompareR2(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) { - TestMatrixDot(2, 1, 1, true, true); -} +std::vector CreateDotTestParameters() { + std::vector params; -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) { - TestMatrixDot(8, 8, 1, true, true); -} + auto add_matrix_matrix_dot_test = [&](int m, int k, int n) { + for (bool lhs_row_major : {true, false}) { + for (bool rhs_row_major : {true, false}) { + params.push_back({/*m=*/m, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/false, /*addend_row_major=*/true}); + } + } + }; + + auto add_matrix_vector_dot_test = [&](int k, int n) { + for (bool has_addend : {false, true}) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, + /*has_addend=*/has_addend, /*addend_row_major=*/true}); + if (n != 1) { + params.push_back( + {/*m=*/n, /*k=*/k, /*n=*/1, + /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, + /*has_addend=*/has_addend, /*addend_row_major=*/true}); + } + } + }; -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) { - TestMatrixDot(16, 1, 1, true, true); -} + add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); + add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); + add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) { - TestMatrixDot(16, 3, 1, true, true); -} + add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); + add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); + add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); + add_matrix_vector_dot_test(/*k=*/8, /*n=*/2); + add_matrix_vector_dot_test(/*k=*/2, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/259, /*n=*/258); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) { - TestMatrixDot(3, 3, 1, true, true); + return params; } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) { - TestMatrixDot(29, 29, 1, true, true); -} +INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, + ::testing::ValuesIn(CreateDotTestParameters()), + PrintDotTestParam); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) { - TestMatrixDot(1, 8, 2, true, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { + TestSquareMatrixDot(false, false); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) { - TestMatrixDot(1, 2, 8, true, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { + TestSquareMatrixDot(false, true); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) { - TestMatrixDot(259, 258, 1, true, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { + TestSquareMatrixDot(true, false); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) { - TestMatrixDot(259, 258, 1, false, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { + TestSquareMatrixDot(true, true); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { + TestSquareMatrixDot(false, false); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { - TestSquareMatrixDot(false, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { + TestSquareMatrixDot(false, true); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { - TestSquareMatrixDot(true, false); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { + TestSquareMatrixDot(true, false); } -TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { + TestSquareMatrixDot(true, true); } XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { @@ -561,5 +583,95 @@ TEST_F(DotOperationTest, TransposeFolding) { } } +TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { + auto prim_type = primitive_util::NativeToPrimitiveType(); + + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + + ComputationBuilder builder(client_, TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + "rhs_arg_0"); + auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), + "rhs_arg_1"); + auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), + "rhs_arg_2"); + auto result = builder.Dot( + lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); + + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0, 2.0}})); + + TF_ASSERT_OK_AND_ASSIGN( + auto arg_0_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_1_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_2_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); + + Array2D expected({{53.0, 74.0}, {45.0, 66.0}}); + ComputeAndCompareR2( + &builder, expected, + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); +} + +TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { + auto prim_type = primitive_util::NativeToPrimitiveType(); + + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + + ComputationBuilder builder(client_, TestName()); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + "lhs_arg_0"); + auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), + "lhs_arg_1"); + auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), + "lhs_arg_2"); + auto result = builder.Dot( + builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); + + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0}, {2.0}})); + + TF_ASSERT_OK_AND_ASSIGN( + auto arg_0_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_1_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_2_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); + + Array2D expected({{38.0, 36.0}, {93.0, 91.0}}); + ComputeAndCompareR2( + &builder, expected, + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 8baaf39e3cf8fa7f6fa4a0224c1297f82e0d92aa..59be32a8ff584a6189302a0835ba74b2e08956b1 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -559,20 +559,20 @@ void BM_DynamicSlice(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. - auto shape_size_fn = [client](const Shape& shape) { - return client->backend().transfer_manager()->GetByteSizeRequirement(shape); - }; - auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0, - shape_size_fn) + auto buffer = client->backend() + .transfer_manager() + ->AllocateScopedShapedBuffer( + start_indices_shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, - buffer->mutable_buffer({}))); + executors[device_ordinal], *start_indices_literal, *buffer)); std::unique_ptr executable = - client->Compile(computation, {&buffer->shape()}, ExecutableBuildOptions()) + client + ->Compile(computation, {&buffer->on_host_shape()}, + ExecutableBuildOptions()) .ConsumeValueOrDie(); // Run some warm-up executions. diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 2686afccc216095345dbb7b43e916fbbe7c8ea39..a292eab1d198fbf69c6dc81c780487ea46756f72 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -816,7 +816,8 @@ void BM_ParallelFusion(int num_iters) { std::unique_ptr executable = client ->Compile(computation, - {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()}, + {&buffer0->on_host_shape(), &buffer1->on_host_shape(), + &buffer2->on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d73c05ff92578209143e0679558848160cae99bd..a27e0f2c106c2ffa2ba108e1963e7111fd347482 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -15,13 +15,22 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include #include #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -30,44 +39,237 @@ namespace se = ::perftools::gputools; namespace xla { +namespace { + +using tensorflow::StringPiece; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::optional; + +constexpr char kInterpreter[] = "interpreter"; + +// Helper functions to get test and reference platforms. +se::Platform* GetReferencePlatform() { + auto result = PlatformUtil::GetPlatform(kInterpreter); + TF_CHECK_OK(result.status()) << "could not get interpreter platform"; + return result.ValueOrDie(); +} + +se::Platform* GetTestPlatform() { + auto result = PlatformUtil::GetDefaultPlatform(); + TF_CHECK_OK(result.status()) << "could not get test platform"; + return result.ValueOrDie(); +} + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { + if (lhs.parameters_size() != rhs.parameters_size()) { + return false; + } + for (int i = 0; i < lhs.parameters_size(); i++) { + if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) { + return false; + } + } + return ShapeUtil::Equal(lhs.result(), rhs.result()); +} + +ProgramShape GetProgramShapeWithLayout(const HloModule& module) { + ProgramShape program_shape; + const auto* entry = module.entry_computation(); + for (const auto* param : entry->parameter_instructions()) { + *program_shape.add_parameters() = param->shape(); + *program_shape.add_parameter_names() = param->name(); + } + *program_shape.mutable_result() = entry->root_instruction()->shape(); + return program_shape; +} + +} // namespace + +HloTestBase::HloTestBase() + : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} + +HloTestBase::HloTestBase(se::Platform* test_platform, + se::Platform* reference_platform) + : test_runner_(test_platform), reference_runner_(reference_platform) { + hlo_verifier_ = MakeUnique([this](const Shape& shape) { + return backend().transfer_manager()->GetByteSizeRequirement(shape); + }); +} + /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} +/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); + return debug_options; +} - config.set_debug_options(debug_options); - - return MakeUnique(TestName(), VersionedComputationHandle(), - config); +StatusOr> HloTestBase::Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments); } -StatusOr HloTestBase::Execute( +std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape) { - return runner_.Execute(std::move(module), arguments, result_shape); + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } -se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - return runner_.TransferToDevice(literal).ValueOrDie(); +StatusOr> HloTestBase::MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor) { + std::unique_ptr reference_module = test_module.Clone(); + const auto& program_shape = GetProgramShapeWithLayout(test_module); + + if (reference_preprocessor != nullptr) { + reference_preprocessor(reference_module.get()); + if (!ProgramShapesEqual(program_shape, + GetProgramShapeWithLayout(*reference_module))) { + return InvalidArgument( + "reference preprocessor must not modify the program shape"); + } + } + TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(), + reference_module.get())); + return std::move(reference_module); } -std::unique_ptr HloTestBase::TransferFromDevice( - const Shape& shape, se::DeviceMemoryBase device_base) { - return runner_.TransferFromDevice(shape, device_base).ValueOrDie(); +template +StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor) { + static_assert( + std::is_same::value || + std::is_same, LiteralPtr>::value, + "The LiteralPtr type only accepts Literal* or std::unique_ptr."); + TF_RETURN_IF_ERROR( + VerifyHloModule(*test_runner_.backend().platform(), module.get())); + TF_ASSIGN_OR_RETURN(auto reference_module, + MakeReferenceModule(*module, reference_preprocessor)); + + // Execute on two backends. + TF_ASSIGN_OR_RETURN( + auto test, + test_runner_.Execute(std::move(module), arguments, run_hlo_passes)); + TF_ASSIGN_OR_RETURN(auto reference, + reference_runner_.Execute(std::move(reference_module), + arguments, run_hlo_passes)); + return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + error); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { - return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie(); +template +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/true, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +template +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/false, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompare>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompareNoHloPasses>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); } -Backend& HloTestBase::backend() { return runner_.backend(); } +Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 7f068dce36be3546298de2f06bf6d33446d07ca2..4aea9fc9fd027231106e529eb16bcd43f23fbe1c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -24,52 +24,150 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" namespace xla { -// A base class for tests which build and run HLO code. This is a lower level of -// abstraction than using the client interface and enables, for one, explicitly -// building a graph of HLO instructions to run. +// A base class for tests which build and/or run HLO code. The class includes +// support for running an HLO module on two platforms and compare the results. +// This is a lower level of abstraction than using the client interface and +// enables, for one, explicitly building a graph of HLO instructions to run. +// +// This can also be used to write text/file-based test cases. Note that the test +// target is responsible for linking the needed backends. A covenient way to do +// this is to make it an xla_test: it will generate test targets linking with +// the respective backends, which will be used as the test backend; the +// interpreter backend is already linked with hlo_test_base so it will be the +// default reference backend. For example, if you want to compare both cpu vs. +// interpreter, and gpu vs. interpreter, you can: +// +// xla_test ( +// name = "sample_text_test", +// srcs = ["sample_text_test.cc"], +// backends = [ +// "cpu", +// "gpu", +// ], +// deps = [ +// "//third_party/tensorflow/compiler/xla/tests:hlo_test_base", +// ... +// ], +// ) +// +// For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { protected: - HloTestBase() {} + // This uses the interpreter backend as the reference backend and + // automatically finds another supported backend as the test backend. If the + // interpreter is the only supported backend, it will be both the test backend + // and the reference backend. + HloTestBase(); + + // If your test doesn't use interpreter as the reference backend, you can use + // this constructor. Note that your test target is responsible for linking in + // both needed backends. + HloTestBase(::perftools::gputools::Platform* test_platform, + ::perftools::gputools::Platform* reference_platform); ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. It's recommended to use this method to - // create all HloModules for tests. + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. static std::unique_ptr CreateNewModule(); - // Executes the given module and returns a global data handle. - StatusOr Execute( + // Populates debug options from command-line flags and adjusts the options for + // testing. It is recommended to use this when you need to pass in + // DebugOptions, e.g. when creating a module from a string or a file. + static DebugOptions GetDebugOptionsForTest(); + + // Executes the given module and return the result as a Literal. + StatusOr> Execute( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape); + tensorflow::gtl::ArraySlice arguments); - // Transfers the given literal to the device and returns the data handle. - perftools::gputools::DeviceMemoryBase TransferToDevice( - const Literal& literal); + std::unique_ptr ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments); + + // Executes the given hlo module on two backends and compares results. + // + // 'arguments': the input of the hlo module. The LiteralPtr type accepts + // Literal* or std::unique_ptr. + // + // 'error': if has value, expects the results to be near (within the error + // bound). Otherwise, expects the results to be equal. + // + // 'reference_preprocessor': the module should be ready to run on the test + // backend, but it might need to be tailored so that it is able to run on the + // reference backend. Note that the program shape of the module must not be + // modified. + template + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Same as above, except that the module will be executed without Hlo + // optimization. + template + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; - // Transfers the array referred to by the given handle from the device and - // returns as a Literal. - std::unique_ptr TransferFromDevice( - const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); + // Executes an hlo module with fake inputs and compares the results. + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; - // Executes the given module and return the result as a Literal. - std::unique_ptr ExecuteAndTransfer( + // Same as above, except that the module will be executed without Hlo + // optimization. + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments); + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Convenient wrappers for executing and comparing an hlo module with fake + // input. Module can be passed in directly, or parsed from an hlo_string, + // or loaded from a file. + ::testing::AssertionResult RunAndCompare( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPasses( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; // Convenience method to force the layout of a given parameter in a module. // The layout of parameter number 'param_no' in the 'module' is set to @@ -99,14 +197,38 @@ class HloTestBase : public ::testing::Test { ->Clear(); } + // Return an HLO verifier constructed for the test backend. + HloVerifier& verifier() const { return *hlo_verifier_; } + static string TestName(); - // Returns the backend owned by the HloRunner. + // Returns the backend owned by the test runner. Backend& backend(); - HloRunner runner_; + HloRunner test_runner_; + HloRunner reference_runner_; + + std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; + + private: + // Given the test module, makes a reference module that is ready to run on the + // reference platform. This assumes that the given module is ready to run on + // the test platform. + StatusOr> MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor); + + // Runs the module on two platforms with or without running hlo passes and + // compares the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. + template + StatusOr<::testing::AssertionResult> RunAndCompareInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/isolated_convolution.hlo b/tensorflow/compiler/xla/tests/isolated_convolution.hlo new file mode 100644 index 0000000000000000000000000000000000000000..9452780930efbb1ecc13b35cd4ab53678d36c37f --- /dev/null +++ b/tensorflow/compiler/xla/tests/isolated_convolution.hlo @@ -0,0 +1,8 @@ +HloModule convolution.167: + +ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] { + %parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0) + %parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1) + ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f +} + diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 75c9a0d3fb5f11bbf051cd94250212faa30d3688..fb425fe6f3cfbb35d7824f3dd1b7d3a2f869313f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -57,7 +57,8 @@ namespace xla { } for (int i = 0; i < expected.tuple_shapes_size(); ++i) { ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) + << "mismatch in tuple index " << i; if (!result) { return result; } @@ -100,6 +101,58 @@ namespace xla { ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); } +/* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector> converted_elements; + for (const auto& element : literal.tuple_literals()) { + converted_elements.push_back(ConvertBF16ToF32(element)); + } + return Literal::MakeTupleOwned(std::move(converted_elements)); + } + + if (literal.shape().element_type() != BF16) { + return MakeUnique(literal); + } + Shape converted_shape = literal.shape(); + converted_shape.set_element_type(F32); + auto converted = Literal::CreateFromShape(converted_shape); + if (!ShapeUtil::HasZeroElements(converted_shape)) { + std::vector index(converted_shape.dimensions_size(), 0); + do { + converted->Set(index, + static_cast(literal.Get(index))); + } while (IndexUtil::BumpIndices(converted_shape, &index)); + } + return converted; +} + +/* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector> converted_elements; + for (const auto& element : literal.tuple_literals()) { + converted_elements.push_back(ConvertF32ToBF16(element)); + } + return Literal::MakeTupleOwned(std::move(converted_elements)); + } + + if (literal.shape().element_type() != F32) { + return MakeUnique(literal); + } + Shape converted_shape = literal.shape(); + converted_shape.set_element_type(BF16); + auto converted = Literal::CreateFromShape(converted_shape); + if (!ShapeUtil::HasZeroElements(converted_shape)) { + std::vector index(converted_shape.dimensions_size(), 0); + do { + converted->Set( + index, static_cast(literal.Get(index))); + } while (IndexUtil::BumpIndices(converted_shape, &index)); + } + return converted; +} + namespace { string Hostname() { @@ -281,23 +334,45 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, return result; } -/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, - const Literal& actual) { +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple( + const Literal& expected, const Literal& actual) { VLOG(1) << "expected: " << expected.ToString(); VLOG(1) << "actual: " << actual.ToString(); - ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); - ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); + if (!ShapeUtil::IsTuple(expected.shape()) || + !ShapeUtil::IsTuple(actual.shape())) { + return ::testing::AssertionFailure() + << "tuples expected shape = " << expected.shape().ShortDebugString() + << " actual shape = " << actual.shape().ShortDebugString(); + } AssertEqualShapes(expected.shape(), actual.shape()); + + ::testing::AssertionResult err = ::testing::AssertionSuccess(); for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); const auto& expected_element = expected.tuple_literals(i); const auto& actual_element = actual.tuple_literals(i); - if (ShapeUtil::IsTuple(expected_element.shape())) { - ExpectEqualTuple(expected_element, actual_element); - } else { - ExpectEqual(expected_element, actual_element); + + ::testing::AssertionResult res = [&] { + if (ShapeUtil::IsTuple(expected_element.shape())) { + return EqualTuple(expected_element, actual_element); + } else { + return Equal(expected_element, actual_element); + } + }(); + + if (!res && err) { + err = res; } } + + return err; +} + +/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, + const Literal& actual) { + EXPECT_TRUE(EqualTuple(expected, actual)); } namespace { @@ -340,6 +415,9 @@ class NearComparator { multi_index_.resize(expected.shape().dimensions_size(), 0); switch (expected.shape().element_type()) { + case BF16: + ExpectLiteralsNear(expected, actual, 0); + break; case F32: ExpectLiteralsNear(expected, actual, 0); break; @@ -525,6 +603,13 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, << message; } +template <> +bool NearComparator::ExpectValuesNear(bfloat16 expected, + bfloat16 actual) { + return ExpectValuesNear(static_cast(expected), + static_cast(actual)); +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( @@ -553,33 +638,33 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, if (!ShapeUtil::IsTuple(expected.shape()) || !ShapeUtil::IsTuple(actual.shape())) { return ::testing::AssertionFailure() - << "tuples expected expected shape = " - << expected.shape().ShortDebugString() + << "tuples expected shape = " << expected.shape().ShortDebugString() << " actual shape = " << actual.shape().ShortDebugString(); } AssertEqualShapes(expected.shape(), actual.shape()); + + ::testing::AssertionResult err = ::testing::AssertionSuccess(); for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); const auto& expected_element = expected.tuple_literals(i); const auto& actual_element = actual.tuple_literals(i); - if (ShapeUtil::IsTuple(expected_element.shape())) { - auto ret = NearTuple(expected_element, actual_element, error); - if (!ret) { - return ret; - } - } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) { - auto ret = Near(expected_element, actual_element, error); - if (!ret) { - return ret; - } - } else { - auto ret = Equal(expected_element, actual_element); - if (!ret) { - return ret; + + ::testing::AssertionResult res = [&] { + if (ShapeUtil::IsTuple(expected_element.shape())) { + return NearTuple(expected_element, actual_element, error); + } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) { + return Near(expected_element, actual_element, error); + } else { + return Equal(expected_element, actual_element); } + }(); + + if (err && !res) { + err = res; } } - - return ::testing::AssertionSuccess(); + return err; } /* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected, @@ -588,6 +673,32 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, EXPECT_TRUE(NearTuple(expected, actual, error)); } +/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + bool is_tuple = ShapeUtil::IsTuple(expected.shape()); + if (error.has_value()) { + if (is_tuple) { + VLOG(1) << "Expects near tuple"; + return NearTuple(expected, actual, *error); + } + VLOG(1) << "Expects near"; + return Near(expected, actual, *error); + } + if (is_tuple) { + VLOG(1) << "Expects equal tuple"; + return EqualTuple(expected, actual); + } + VLOG(1) << "Expects equal"; + return Equal(expected, actual); +} + +/*static*/ void LiteralTestUtil::ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + EXPECT_TRUE(NearOrEqual(expected, actual, error)); +} + /* static */ string LiteralTestUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index) { return tensorflow::strings::StrCat( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 467d44b857b74d2a38e9b3f8a32a9b1d39a4a10d..f53553c70170bdcda717e72ffd791016effd0774 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -59,6 +60,16 @@ class LiteralTestUtil { static void AssertEqualShapesAndLayouts(const Shape& expected, const Shape& actual); + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); + + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); + // Asserts that the expected and actual literals are (bitwise) equal for all // elements in the literal. Also, asserts that the rank, dimensions sizes, and // primitive type are equal. @@ -100,6 +111,10 @@ class LiteralTestUtil { static void ExpectR4EqualArray4D(const Array4D& expected, const Literal& actual); + // Returns whether the two tuples are equal. + static ::testing::AssertionResult EqualTuple( + const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + // Expects that the values of the elements in the expected and actual tuples // are equal. Tuples are matched recursively. static void ExpectEqualTuple(const Literal& expected, const Literal& actual); @@ -167,6 +182,19 @@ class LiteralTestUtil { static void ExpectNearTuple(const Literal& expected, const Literal& actual, const ErrorSpec& error); + // If the error spec is given, returns whether the expected and the actual are + // within the error bound; otherwise, returns whether they are equal. Tuples + // will be compared recursively. + static ::testing::AssertionResult NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; + + // If the error spec is given, expects the expected and the actual to be near; + // otherwise, expects them to be equal. Tuples will be compared recursively. + static void ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error); + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, // dimension 1 equal to 8. diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 62fab6a22434ba20f5d7c068d876188e0661e02e..b5b95967ff9162301a092f3a57996e0f3f78658f 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -73,8 +73,8 @@ class LLVMCompilerTest : public ::testing::Test { compiler->SetPostOptimizationHook(post_opt_hook); ASSERT_TRUE(compiler - ->Compile(std::move(hlo_module), - backend_->default_stream_executor()) + ->RunBackend(std::move(hlo_module), + backend_->default_stream_executor()) .ok()); // Test that hooks were called. diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index fbf9739dbceec2b941101881fe28acb38a2003be..e3298e98c67969f97adfdf15d22dfe72592b56aa 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -138,13 +138,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { // Create x as a col-major array. auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); - EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(x_array->on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); - EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(y_array->on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); std::unique_ptr result_colmaj = @@ -179,7 +179,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, *ShapedBufferToLiteral(*result_colmaj), @@ -191,7 +191,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, *ShapedBufferToLiteral(*result_rowmaj), @@ -213,8 +213,8 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, @@ -241,8 +241,8 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, @@ -320,8 +320,8 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_buffer.get(), y_buffer.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, @@ -874,11 +874,13 @@ XLA_TEST_F(LocalClientExecuteTest, tensorflow::ThreadOptions(), "execute_thread", [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); - ASSERT_IS_OK(local_client_->TransferToInfeed( - *Literal::CreateR1({-5.0, 123.0, 42.0}))); + ASSERT_IS_OK(local_client_->TransferToInfeedLocal( + *Literal::CreateR1({-5.0, 123.0, 42.0}), + local_client_->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - local_client_->TransferFromOutfeed(&shape)); + local_client_->TransferFromOutfeedLocal( + shape, local_client_->default_device_ordinal())); LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); } @@ -904,20 +906,18 @@ void BM_LocalClientOverhead(int num_iters) { builder.Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); - auto shape_size_fn = [client](const Shape& shape) { - return client->backend().transfer_manager()->GetByteSizeRequirement(shape); - }; - auto buffer = ScopedShapedBuffer::Allocate( - shape, &allocator, /*device_ordinal=*/0, shape_size_fn) - .ConsumeValueOrDie(); + auto buffer = + transfer_manager + ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, buffer->mutable_buffer({}))); + executors[device_ordinal], *literal, *buffer)); const int kWarmups = 2; - auto executable_status = client->Compile(computation, {&buffer->shape()}, - ExecutableBuildOptions()); + auto executable_status = client->Compile( + computation, {&buffer->on_host_shape()}, ExecutableBuildOptions()); ASSERT_IS_OK(executable_status); std::unique_ptr executable = executable_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 062a9246e49598d5d03dce8c1f437138923449bf..96b976d25d75d35f46adfd104a03aceb363661eb 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -188,7 +188,7 @@ LocalClientTestBase::ExecuteLocally( const ExecutableRunOptions& run_options) { std::vector argument_layouts(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { - argument_layouts[i] = &arguments[i]->shape(); + argument_layouts[i] = &arguments[i]->on_host_shape(); } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 22d2b917a1d55f4f453e21c2d8fea38e32ff796b..62d24a11fdb164ed6776d1e83877cf3acd319cc6 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -76,8 +76,11 @@ class MultiOutputFusionTest : public HloTestBase { elem_shape2, HloOpcode::kAdd, broadcast, param1)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( elem_shape2, HloOpcode::kSubtract, param1, broadcast)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2)); + HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -96,14 +99,13 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input; - input.PopulateWithValue(2.5f, {size, size}); - auto p1 = TransferToDevice(input); - auto p0 = TransferToDevice(*Literal::CreateR0(-9.0f)); + Literal arg1; + arg1.PopulateWithValue(2.5f, {size, size}); Literal expect; expect.PopulateWithValue(size * 1.5f * 3.5f, {size, size}); - auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + auto actual = ExecuteAndTransfer( + std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } @@ -133,8 +135,11 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {size, 1}), add)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1}), HloOpcode::kDot, sub, reshape)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -157,11 +162,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal input0, input1; input0.PopulateWithValue(2.5f, {size}); input1.PopulateWithValue(1, {size}); - auto p0 = TransferToDevice(input0); - auto p1 = TransferToDevice(input1); Literal expect = *Literal::CreateR1({size * 1.5f * 3.5f}); - auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } }; diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index fda4389f479cdc7a659e4d7c8a2facba55e17e83..c260258d6e9af6dee6075c92cf35dac4ed46abed 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -252,8 +252,8 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { } // Only run the 3,000-parameter tests in opt mode to avoid test timeouts. -// Timeout last observed on 2017-09-12. -#ifndef NDEBUG +// Timeout last observed on 2017-11-20. +#ifdef NDEBUG // TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too // much space in parameter memory for the kernel. @@ -334,6 +334,106 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); } +// Test large number of parameters flowing into a while-loop. +// Construct conceptually the following HLO graph: +// +// p0 = parameter(0) +// p1 = parameter(1) +// ... +// pN = parameter(N) +// result = while (false) { +// p0 += (1, 1); +// p1 += (1, 1); +// ... +// pN += (1, 1) +// } +// result = {p0, p1, ..., pN} +// +// TODO(b/70173746): Times out during compilation on GPU and CPU backends as of +// 2017-12-12. +XLA_TEST_F(ParamsTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) { + ComputationBuilder builder(client_, TestName()); + + std::vector> param_data_owner; + constexpr int kParamCount = 1900; + std::vector params; + std::vector parameter_shapes; + for (int i = 0; i < kParamCount; ++i) { + std::unique_ptr literal = Literal::CreateR1({i, i}); + param_data_owner.push_back( + std::move(client_->TransferToServer(*literal)).ValueOrDie()); + ComputationDataHandle param = + builder.Parameter(i, literal->shape(), "param"); + params.push_back(param); + parameter_shapes.push_back(literal->shape()); + } + + // Add bool parameter for the loop condition. Use a parameter HLO instead of a + // constant because DCE may eliminate the while-body otherwise. + std::unique_ptr bool_literal = Literal::CreateR0(false); + param_data_owner.push_back( + std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + ComputationDataHandle bool_param = + builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); + params.push_back(bool_param); + parameter_shapes.push_back(bool_literal->shape()); + + auto init = builder.Tuple(params); + + // Create a computation for the condition: while(bool_param). + Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto condition_parameter = + builder.Parameter(0, while_shape, "condition_parameter"); + builder.GetTupleElement(condition_parameter, kParamCount); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add {1, 1} to the each tuple element. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); + std::vector updates; + for (int i = 0; i < kParamCount; ++i) { + auto add = builder.Add(builder.GetTupleElement(body_parameter, i), + builder.ConstantR1({1, 1})); + updates.push_back(add); + } + // Add bool parameter. + updates.push_back(builder.GetTupleElement(body_parameter, kParamCount)); + + builder.Tuple(updates); + body = builder.Build().ConsumeValueOrDie(); + } + + auto loop = builder.While(condition, body, init); + + std::vector outputs; + for (int i = 0; i < kParamCount; ++i) { + outputs.push_back(builder.GetTupleElement(loop, i)); + } + builder.Tuple(outputs); + + std::vector param_data; + param_data.reserve(param_data_owner.size()); + for (const std::unique_ptr& data : param_data_owner) { + param_data.push_back(data.get()); + } + + std::vector> elements; + std::vector ptrs; + for (int i = 0; i < kParamCount; ++i) { + elements.push_back(Literal::CreateR1({i, i})); + ptrs.push_back(elements.back().get()); + } + ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); +} + #endif XLA_TEST_F(ParamsTest, diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 7bc3185c367f076c9a7d211c9799557e1a91d92f..b09ccdd679b6c8f628e40f78f58dbd1734926af6 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -352,15 +352,13 @@ XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) { XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); } XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { @@ -369,15 +367,13 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/false, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 6c9b62b48d8bb2ad93b2ce98839e5e52d8eaa8cc..bf81514bc900792d6c687a6044b83e91920ed8bb 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -41,16 +41,40 @@ limitations under the License. namespace xla { namespace { -class ReduceWindowTest : public ClientLibraryTestBase { +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. +static std::array use_bfloat16_params{false}; +#endif + +class ReduceWindowTestBase : public ClientLibraryTestBase { public: - ReduceWindowTest() : builder_(client_, TestName()) {} + ErrorSpec DefaultErrorSpec() const { + if (use_bfloat16()) { + return ErrorSpec(1e-1, 5e-2); + } else { + return ErrorSpec(1e-3, 1e-3); + } + } +}; + +class ReduceWindowTest : public ::testing::WithParamInterface, + public ReduceWindowTestBase { + public: + ReduceWindowTest() : builder_(client_, TestName()) { + set_use_bfloat16(GetParam()); + } void ReduceWindowAdd(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, builder_.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder_), + auto init = + CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); } @@ -58,30 +82,32 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow( - input, builder_.ConstantLiteral(Literal::MinValue(F32)), - CreateScalarMax(), window_dimensions, window_strides, padding); + auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); + builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions, + window_strides, padding); } void ReduceWindowMin(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, - builder_.ConstantLiteral(Literal::MaxValue(F32)), - CreateScalarMinComputation(F32, &builder_), + auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarMinComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); } ComputationBuilder builder_; }; -TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { - const auto input = builder_.ConstantR1({1, 1, 1, 1}); - const auto init_value = builder_.ConstantR0(0); +TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({1, 1, 1, 1}), &builder_); + const auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); builder_.ReduceWindow(input, init_value, - CreateScalarAddComputation(F32, &builder_), + CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{1, 2}, /*window_strides=*/{1}, Padding::kValid); ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) @@ -90,79 +116,97 @@ TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { ::testing::HasSubstr("Want input dimensions size")); } -TEST_F(ReduceWindowTest, Min3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); +// Regression test for b/68964348. +TEST_P(ReduceWindowTest, R0ReduceWindow) { + const auto input = + CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); + const auto init = + CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{}, + /*window_strides=*/{}, Padding::kSame); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, + ErrorSpec(0.00001)); +} + +TEST_P(ReduceWindowTest, Min3In5Stride2) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({100, 1}), {}, + ErrorSpec(0.00001)); } -XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) { +XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { Array4D input_array(1, 0, 2, 1); - - const auto input = builder_.ConstantR4FromArray4D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, NonSquareSmall) { +TEST_P(ReduceWindowTest, NonSquareSmall) { Array4D input_array(1, 2, 2, 1); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, MiddleDimsSmall) { +TEST_P(ReduceWindowTest, MiddleDimsSmall) { Array4D input_array(1, 3, 3, 1); - input_array.FillRandom(2.f); - - const auto input = builder_.ConstantR4FromArray4D(input_array); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, Along2ndMinorDim) { +TEST_P(ReduceWindowTest, Along2ndMinorDim) { Array4D input_array(3, 6, 7, 32); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); // The parameters of this reduction mimic feature norm (e.g. LRN). int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2Dims) { +TEST_P(ReduceWindowTest, AmongMajor2Dims) { Array4D input_array(4, 4, 6, 8); input_array.FillWithMinorDimNum(); + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); int win_len = 3; int win_stride = 1; Padding padding = Padding::kSame; - const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); // Reduce only along the x and y dimensions, according to the win_len. ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -170,18 +214,20 @@ TEST_F(ReduceWindowTest, AmongMajor2Dims) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { Array4D input_array(9, 12, 4, 89); - input_array.FillRandom(2.0f); + input_array.FillRandom(2.f, 2.f); int win_len = 3; int win_stride = 2; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; // Reduce only along the x and y dimensions, according to the win_len. @@ -192,56 +238,28 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); -} - -// TODO(b/32173947): Test support for arbitrary-sized padding. -TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { - Array4D input_array(9, 12, 4, 89); // simulate Dim0IsMinor layout - input_array.FillRandom(2.0f); - - int64 rank = 4; - int win_len = 3; - int win_stride = 2; - - const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); - - Padding padding = Padding::kSame; - // Reduce only along the x and y dimensions, according to the win_len. - // Create padding vector with large padding values in the reduction dims. - std::vector> low_high_padding; - low_high_padding.resize(rank, {4, 4}); - - builder_.ReduceWindowWithGeneralPadding( - input_data_handle, builder_.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder_), {win_len, win_len, 1, 1}, - {win_stride, win_stride, 1, 1}, low_high_padding); - - auto result = ReferenceUtil::ReduceWindow4DAdd( - input_array, 0.0f, {win_len, win_len, 1, 1}, - {win_stride, win_stride, 1, 1}, padding); - - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x2) { Array3D input_array(2, 1, 2); input_array(0, 0, 0) = 1000; input_array(0, 0, 1) = 100; input_array(1, 0, 0) = 10; input_array(1, 0, 1) = 1; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kValid); Array3D expected(2, 1, 1); expected(0, 0, 0) = 1100; expected(1, 0, 0) = 11; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { Array3D input_array(2, 1, 3); input_array(0, 0, 0) = 100; input_array(0, 0, 1) = 10; @@ -249,17 +267,18 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { input_array(1, 0, 0) = 500; input_array(1, 0, 1) = 50; input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 2}, Padding::kValid); Array3D expected(2, 1, 1); expected(0, 0, 0) = 110; expected(1, 0, 0) = 550; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { Array3D input_array(2, 1, 3); input_array(0, 0, 0) = 100; input_array(0, 0, 1) = 10; @@ -267,7 +286,7 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { input_array(1, 0, 0) = 500; input_array(1, 0, 1) = 50; input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kSame); @@ -278,30 +297,34 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { expected(1, 0, 0) = 550; expected(1, 0, 1) = 55; expected(1, 0, 2) = 5; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. -XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { +XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D input_array(1, 2, 2, 1); input_array(0, 0, 0, 0) = 1; input_array(0, 0, 1, 0) = 2; input_array(0, 1, 0, 0) = 3; input_array(0, 1, 1, 0) = 4; + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kValid; - - const Shape scalar = ShapeUtil::MakeShape(F32, {}); + const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); auto b = builder_.CreateSubBuilder("unusual"); auto lhs = b->Parameter(0, scalar, "lhs"); auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(b->Add(lhs, rhs), b->ConstantR0(8.0f)); + b->Min(b->Add(lhs, rhs), + CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); Computation reduce_fn = b->BuildAndNoteError(); - builder_.ReduceWindow(input, builder_.ConstantR0(3.0f), reduce_fn, - /*window_dimensions=*/{1, 1, 2, 1}, - /*window_strides=*/{1, 1, 1, 1}, padding); + builder_.ReduceWindow( + input, + CreateConstantFromLiteral(*Literal::CreateR0(3.0f), &builder_), + reduce_fn, + /*window_dimensions=*/{1, 1, 2, 1}, + /*window_strides=*/{1, 1, 1, 1}, padding); const auto reduce_func = [](float arg1, float arg2) { return std::min(arg1 + arg2, 8.0f); @@ -312,17 +335,19 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R4UnitWindow) { +TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); - input_array.Fill(1.0f); + input_array.FillRandom(2.f, 2.f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -330,15 +355,11 @@ TEST_F(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); @@ -348,56 +369,15 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(arg_literal->Populate(generator)); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - // Non-monotonic output layout with minor dims trivial. + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); + std::vector output_layout = {1, 5, 3, 2, 0, 4}; std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - b.AddInstruction(HloInstruction::CreateReduceWindow( - result_shape, input, init_value, window, add_func)); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); auto out_generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { @@ -405,82 +385,37 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(expected->Populate(out_generator)); - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); - - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6Add) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); + auto shape = ShapeUtil::MakeShape(F32, input_dims); + std::unique_ptr arg_literal = Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - Shape shape = ShapeUtil::MakeShape(F32, {8, 8, 6, 6, 8, 8}); - b.AddInstruction(HloInstruction::CreateReduceWindow(shape, input, init_value, - window, add_func)); + + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 9.0f); - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); - - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -490,20 +425,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -513,20 +447,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -536,13 +469,11 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Array4D input_array(6, 4, 10, 130); input_array.FillRandom(2.0f); @@ -551,7 +482,7 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Padding padding = Padding::kSame; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); // Reduce only along the x and y dimensions, according to the win_len. ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -559,36 +490,42 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add24In1152_NoOverlap) { +XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); - const auto input = builder_.ConstantR1(input_vector); + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {32, 32, 32, 32, 32, 32, 32, 32, 32}, - {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral( + &builder_, + *Literal::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add128In128Stride128) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); +XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { + std::vector input_vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {1088}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + DefaultErrorSpec()); } // Regression test for a bug that appeared in Inception (b/34784899). -TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { Array2D input_array(14, 14, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {14, 14}); + const auto input = CreateConstantFromArray(input_array, &builder_); int win_len = 3; int stride = 1; @@ -598,13 +535,14 @@ TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {6, 4}); + ComputationDataHandle input = builder_.Broadcast( + CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -612,9 +550,13 @@ TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } +INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, + ::testing::ValuesIn(use_bfloat16_params)); + enum Reducer { kAdd, kMax }; struct R4ReduceWindowTestData { @@ -628,30 +570,36 @@ struct R4ReduceWindowTestData { }; string R4ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(data.param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(data.param.pad_high, "x"), // - (data.param.reducer == kAdd) ? "add" : "max"); - CHECK(data.param.reducer == kAdd || data.param.reducer == kMax); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // + (param.reducer == kAdd) ? "add" : "max"); + CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R4ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface { +class R4ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: + R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + void DoIt() { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); const float kInitValue = 0.0f; @@ -660,23 +608,24 @@ class R4ReduceWindowTest input.FillIota(1); std::unique_ptr input_literal = Literal::CreateR4FromArray4D(input); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); std::vector> padding(4); for (int i = 0; i < 4; ++i) { padding[i] = {param.pad_low[i], param.pad_high[i]}; } - auto parameter = b.Parameter(0, input_literal->shape(), "p0"); - auto pad_value = b.ConstantR0(kInitValue); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd - ? CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); b.ReduceWindowWithGeneralPadding( /*operand=*/parameter, - /*init_value=*/pad_value, + /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, @@ -694,8 +643,8 @@ class R4ReduceWindowTest /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareR4(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); } }; @@ -711,6 +660,14 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_high=*/{0, 0, 0, 0}, /*reducer=*/kAdd}, + // Arbitrary padding (not kSame or kValid). + R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{4, 4, 0, 0}, + /*pad_high=*/{4, 4, 0, 0}, + /*reducer=*/kAdd}, + // Zero base bound edge case. R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1}, /*window_bounds=*/{1, 1, 1, 1}, @@ -824,9 +781,11 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowTestInstantiation, R4ReduceWindowTest, - ::testing::ValuesIn(kR4ReduceWindowTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowTestInstantiation, R4ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; @@ -849,10 +808,11 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowLargeTestInstantiation, - R4ReduceWindowLargeTest, - ::testing::ValuesIn(kR4ReduceWindowLargeTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); struct R2ReduceWindowTestData { int64 base_bounds[2]; @@ -900,26 +860,33 @@ struct R2ReduceWindowTestData { }; string R2ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", data.param.layout[0], "_", data.param.layout[1], // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__padding_", param.padding == Padding::kSame ? "same" : "valid", // + "__layout_", param.layout[0], "_", param.layout[1], // + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R2ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R2ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R2ReduceWindowTest, Add) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; @@ -927,12 +894,15 @@ TEST_P(R2ReduceWindowTest, Add) { std::unique_ptr input_literal = Literal::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/CreateScalarAddComputation(F32, &b), + + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); @@ -940,90 +910,145 @@ TEST_P(R2ReduceWindowTest, Add) { /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/param.padding); - ComputeAndCompareR2(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); } -INSTANTIATE_TEST_CASE_P(R2ReduceWindowTestInstantiation, R2ReduceWindowTest, - ::testing::ValuesIn(kR2TestCases), - R2ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R2ReduceWindowTestInstantiation, R2ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR2TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R2ReduceWindowTestDataToString); struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; int64 strides[1]; - Padding padding; + int64 pad_low[1]; + int64 pad_high[1]; Reducer reducer; } kR1TestCases[] = { {/*base_bounds=*/{1}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{3}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{4}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30}, + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{30}, /*strides=*/{27}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 17}, + /*window_bounds=*/{7}, /*strides=*/{64}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32}, + /*pad_low=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{32}, /*strides=*/{56}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{3}, /*strides=*/{2}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{0}, + /*pad_high=*/{5}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{5}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kAdd}, }; string R1ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), + "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), + "__strides_", tensorflow::str_util::Join(param.strides, "x"), + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R1ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R1ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R1ReduceWindowTest, DoIt) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd || param.reducer == kMax); const float kInitValue = 0.0f; @@ -1031,18 +1056,24 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + + std::vector> padding(1); + padding[0] = {param.pad_low[0], param.pad_high[0]}; auto computation = param.reducer == kAdd - ? CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/computation, - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1052,14 +1083,17 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + /*stride=*/param.strides, + /*padding=*/padding); - ComputeAndCompareR1(&b, tensorflow::gtl::ArraySlice(*expected), - {input_arg.get()}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateR1(*expected), + {input_arg.get()}, DefaultErrorSpec()); } -INSTANTIATE_TEST_CASE_P(R1ReduceWindowTestInstantiation, R1ReduceWindowTest, - ::testing::ValuesIn(kR1TestCases), - R1ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R1ReduceWindowTestInstantiation, R1ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR1TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R1ReduceWindowTestDataToString); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d235b9a1580ecbd6b82a69fca53d259912ff375e..ddd50d7a5864d73de7916ce736bb7cd40c1c4dc9 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -41,326 +41,467 @@ limitations under the License. namespace xla { namespace { -class ReshapeTest : public ClientLibraryTestBase { +// Use a bool parameter to indicate whether to use bfloat16. +class ReshapeTest : public ::testing::WithParamInterface, + public ClientLibraryTestBase { public: + ReshapeTest() { set_use_bfloat16(GetParam()); } + ErrorSpec zero_error_spec_{0.0}; }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { +XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. -XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { +XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - auto reshape = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); - ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); + auto expected_literal = Literal::CreateR0(1.0f); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { +XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR0(1.0f); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - a = builder.Neg(a); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + auto a = builder.Neg(parameter); auto reshape = builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); - ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, - zero_error_spec_); + auto expected_literal = Literal::CreateR1({-1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial0x3) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(0, 3); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // TODO(b/29185393): Make this work with the GPU backend. The GPU backend // does not handle zero-sized shapes correctly. Failed last on 2017-05-15 // with an incorrect result rank. -XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(0, 3)); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {param0_data.get()}, - zero_error_spec_); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial3x0) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(3, 0); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional row vector to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial1x3) { +XLA_TEST_P(ReshapeTest, Trivial1x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional column vector to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial3x1) { +XLA_TEST_P(ReshapeTest, Trivial3x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f}, {2.0f}, {3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Splits an empty vector into an empty matrix. -XLA_TEST_F(ReshapeTest, R1ToR2_0_To_2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateR1({}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Splits a vector into a matrix. -XLA_TEST_F(ReshapeTest, R1ToR2_6_To_2x3) { +XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); - Array2D expected_2x3({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected_2x3, {}, zero_error_spec_); + auto input_literal = + Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); + auto expected_literal = + Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 2x0 array to a 0x2 array. -XLA_TEST_F(ReshapeTest, Reshape0x2To2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 0}); - - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional row vector to a column vector. -XLA_TEST_F(ReshapeTest, ReshapeRowToCol) { +XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { ComputationBuilder builder(client_, TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); - auto a = builder.ConstantR2FromArray2D(*simple); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{3, 1}); + auto input_literal = Literal::CreateFromArray(*simple); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); - ComputeAndCompareR2(&builder, *expected, {}, zero_error_spec_); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array. -XLA_TEST_F(ReshapeTest, TransposeAsReshape) { +XLA_TEST_P(ReshapeTest, TransposeAsReshape) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 4}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 0x4 array with ComputationBuilder::Trans. -XLA_TEST_F(ReshapeTest, Transpose0x4) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 4)); - auto result = builder.Transpose(a, {1, 0}); - - ComputeAndCompareR2(&builder, Array2D(4, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array with ComputationBuilder::Trans. -XLA_TEST_F(ReshapeTest, Transpose4x3) { +XLA_TEST_P(ReshapeTest, Transpose4x3) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Transpose(a, {1, 0}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(6, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 3, 0, 0}); - - ComputeAndCompareR4(&builder, Array4D(2, 3, 0, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); + auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ReshapeR4ToR2ZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR4FromArray4D(Array4D(2, 3, 4, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{24, 0}); - - ComputeAndCompareR2(&builder, Array2D(24, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 6}); - - auto expected2x6 = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); - ComputeAndCompareR2(&builder, *expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); + + auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -// Reshapes a 2-dimensional array with dimensions that are not just a -// rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 6)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 0}); - - ComputeAndCompareR2(&builder, Array2D(3, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{2, 6}); - - Array2D expected2x6({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, - {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); - ComputeAndCompareR2(&builder, expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); + Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, + {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); + auto expected_literal = Literal::CreateFromArray(expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // The following tests use the same input 3D array; they test the examples we // show for the Reshape operation in the operation_semantics document. // TODO(b/34503277): find a way to show this code in the documentation without // duplication on the TF documentation server. -Array3D v_array_for_doc_R3_tests({{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}); - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_012) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_120) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, - 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 20, 30}, - {40, 11, 21}, - {31, 41, 12}, - {22, 32, 42}, - {15, 25, 35}, - {45, 16, 26}, - {36, 46, 17}, - {27, 37, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{2, 6, 2}); - Array3D expected( +static Array3D ArrayForDocR3Tests() { + return Array3D({{{10, 11, 12}, {15, 16, 17}}, + {{20, 21, 22}, {25, 26, 27}}, + {{30, 31, 32}, {35, 36, 37}}, + {{40, 41, 42}, {45, 46, 47}}}); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, + 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, + 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); + auto expected_literal = Literal::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareR3(&builder, expected, {}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses the low dimensions of a 4D tensor to get a 2D matrix, without @@ -378,23 +519,26 @@ XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { // Then we collapse Z be collapsed so we just end up with planes: // // 1 2 3 4 5 6 1 2 3 4 5 6 -XLA_TEST_F(ReshapeTest, FullyConnectedCollapse) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { ComputationBuilder builder(client_, TestName()); Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); - auto a = builder.ConstantR4FromArray4D(t2x2x2x3); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{1, 2, 3}); - - Array2D expected2x12( + auto input_literal = Literal::CreateFromArray(t2x2x2x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); + auto expected_literal = Literal::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected2x12, {}, zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // As above, but uses reshape directly. -XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { ComputationBuilder builder(client_, TestName()); Array4D t(2, 1, 2, 2); t(0, 0, 0, 0) = 0; @@ -405,52 +549,67 @@ XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 0, 1) = 5; t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; - auto a = builder.ConstantR4FromArray4D(t); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{2, 4}); - - Array2D expected({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareR2(&builder, expected, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(t); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); + + auto expected_literal = + Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshape various ranks to a scalar. -XLA_TEST_F(ReshapeTest, ToScalar) { +XLA_TEST_P(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { ComputationBuilder b(client_, TestName()); - auto input = Literal::CreateR1({83.0f}); + auto input_literal = Literal::CreateR1({83.0f}); std::vector ones(rank, 1); // this is {1, ..., 1}. std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); - *input->mutable_shape() = ShapeUtil::MakeShape(F32, ones); - b.Reshape(b.ConstantLiteral(*input), dimensions, {}); + *input_literal->mutable_shape() = ShapeUtil::MakeShape(F32, ones); - ComputeAndCompareR0(&b, 83.0f, {}, zero_error_spec_); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &b, ¶meter); + b.Reshape(parameter, dimensions, {}); + + auto expected_literal = Literal::CreateR0(83.0f); + ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + zero_error_spec_); } } -XLA_TEST_F(ReshapeTest, BadDimensions) { +XLA_TEST_P(ReshapeTest, BadDimensions) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1}), {}, {}); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {}, {}); EXPECT_THAT( ExecuteToString(&b, {}), ::testing::HasSubstr("not a permutation of the operand dimensions")); } -XLA_TEST_F(ReshapeTest, BadNewSizes) { +XLA_TEST_P(ReshapeTest, BadNewSizes) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1, 2}), {1}, {}); + auto input_literal = Literal::CreateR1({1.0f, 2.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } -XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { - const Shape parameter_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); +XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, parameter_shape, "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); - // clang-format off - auto literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ + auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { { {0, 1}, @@ -474,8 +633,12 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }, LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on - std::unique_ptr input = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + Array2D expected_array({ {0, 1, 2, 3, 100, 101, 102, 103}, {222, 333, 444, 555, 666, 777, 888, 999}, @@ -484,72 +647,75 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, + {1, 0}); std::unique_ptr actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); + if (use_bfloat16()) { + expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + } LiteralTestUtil::ExpectEqual(*expected, *actual); } -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}, {{{100, 101, 102, 103}}, {{104, 105, 106, 107}}}, {{{200, 201, 202, 203}}, {{204, 205, 206, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 100, 200, 1}}, {{101, 201, 2, 102}}}, {{{202, 3, 103, 203}}, {{4, 104, 204, 5}}}, {{{105, 205, 6, 106}}, {{206, 7, 107, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); @@ -559,12 +725,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); @@ -572,7 +736,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); @@ -582,12 +747,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); @@ -596,7 +759,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { } // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. -XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { +XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); @@ -606,12 +770,11 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, + /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { @@ -619,10 +782,12 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = Literal::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, NoopReshape) { +XLA_TEST_P(ReshapeTest, NoopReshape) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); @@ -632,18 +797,17 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Reshape(input, /*dimensions=*/{3, 0, 1, 2}, + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, + {2, 3, 0, 1}); std::unique_ptr output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, @@ -652,35 +816,45 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. - EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), - tensorflow::gtl::ArraySlice(output_literal->f32s())); + if (use_bfloat16()) { + auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + EXPECT_EQ(tensorflow::gtl::ArraySlice(expected->bf16s()), + tensorflow::gtl::ArraySlice(output_literal->bf16s())); + } else { + EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), + tensorflow::gtl::ArraySlice(output_literal->f32s())); + } } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { + ComputationBuilder builder(client_, TestName()); + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{0, 1, 2, 3}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {}); + ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{1, 3, 2, 0}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = Literal::CreateR4( + auto expected_2x4x3x1 = Literal::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -691,10 +865,10 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {}); + ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {2, 2, 2, 2}; @@ -706,12 +880,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -723,7 +897,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {1, 1, 250, 300}; @@ -735,12 +909,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -752,7 +926,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {5, 5, 1, 10}; @@ -764,12 +938,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -781,7 +955,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::mt19937 rng; std::uniform_real_distribution distribution; // This happens in NN-Builder MNIST. @@ -794,12 +968,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -811,7 +985,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {3, 3, 1, 3}; @@ -823,12 +997,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) @@ -840,5 +1014,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { zero_error_spec_, &expected->shape()); } +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::Bool()); +#else +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, + ::testing::ValuesIn(std::vector{false})); +#endif + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/sample_file_test.cc b/tensorflow/compiler/xla/tests/sample_file_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..31b104f4e37f77d47f56ff8183ee1de1cc22e44d --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_file_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create a file based testcase +// and compare results on gpu and cpu. + +#include +#include + +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SampleFileTest : public HloTestBase { + protected: + SampleFileTest() + : HloTestBase( + /*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(), + /*reference_platform=*/PlatformUtil::GetPlatform("cpu") + .ValueOrDie()) {} +}; + +TEST_F(SampleFileTest, Convolution) { + const string& filename = "compiler/xla/tests/isolated_convolution.hlo"; + string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + EXPECT_TRUE(RunAndCompareFromFile( + tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4f2b74e3dc9e80f50454b28eb6f2502cef3e681 --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_text_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create textual IR based +// testcases. + +#include +#include + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class SampleTextTest : public HloTestBase {}; + +TEST_F(SampleTextTest, Axpy) { + const string& hlo_string = R"( +HloModule axpy_module: +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001})); +} + +TEST_F(SampleTextTest, Tuple) { + const string& hlo_string = R"( +HloModule TupleCreate_module: +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, nullopt)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c21124750ad512cad69b1483e708613ee2857ac0..4db566f7841829359ea06fe25408048418c547ad 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -211,6 +212,13 @@ class SliceR1Test : public ClientLibraryTestBase, } }; +string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { + const R1Spec& spec = data.param; + return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, + spec.slice_start, spec.slice_limit, + spec.slice_stride); +} + XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_F64) { Run(GetParam()); } @@ -223,30 +231,66 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } -INSTANTIATE_TEST_CASE_P( // - SliceR1TestInstantiation, // - SliceR1Test, // - ::testing::Values( // - R1Spec{10, 0, 0, 1}, // - R1Spec{10, 7, 7, 1}, // - R1Spec{10, 2, 4, 1}, // - R1Spec{10, 2, 4, 2}, // - R1Spec{10, 0, 10, 1}, // - R1Spec{1024, 1024 - 4, 1024, 1}, // - R1Spec{4096, 7, 7 + 1024, 1}, // - R1Spec{10, 0, 10, 2}, // - R1Spec{10, 0, 10, 3}, // - R1Spec{10, 0, 10, 4}, // - R1Spec{10, 0, 10, 5}, // - R1Spec{10, 0, 10, 10}, // - R1Spec{500, 200, 400, 7}, // - R1Spec{4096, 1, 4095, 3}, // - R1Spec{2047, 1024 - 24, 1024 + 160, 31}, // - R1Spec{2047, 1, 2046, 3 * 128}, // - R1Spec{4096, 1024 + 3, 4095, 500}, // - R1Spec{8192, 0, 8192, 1024 * 3 + 400} // - ) // +// Tests for R1 slice ops. +// The format for each testcase is {input size, start, limit, stride}. +// clang-format off +INSTANTIATE_TEST_CASE_P( + SliceR1TestInstantiation, + SliceR1Test, + ::testing::Values( + R1Spec{10, 0, 0, 1}, + R1Spec{10, 7, 7, 1}, + R1Spec{10, 0, 5, 1}, + R1Spec{10, 3, 5, 1}, + R1Spec{10, 0, 10, 1}, + R1Spec{1024, 0, 5, 1}, + R1Spec{1024, 3, 5, 1}, + R1Spec{1024 + 17, 0, 5, 1}, + R1Spec{1024 + 17, 3, 5, 1}, + R1Spec{1024 + 17, 1024, 1024 + 6, 1}, + R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1}, + R1Spec{1024, 1024 - 4, 1024, 1}, + R1Spec{4 * 1024, 7, 7 + 1024, 1}, + R1Spec{4 * 1024, 0, 4 * 1024, 1}, + R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1}, + R1Spec{4 * 1024, 1024, 3 * 1024, 1}, + R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1}, + R1Spec{16 * 1024, 0, 5, 1}, + R1Spec{16 * 1024, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 0, 5, 1}, + R1Spec{16 * 1024 + 17, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1}, + R1Spec{64 * 1024, 0, 64 * 1024, 1}, + R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1}, + R1Spec{64 * 1024, 1024, 63 * 1024, 1}, + R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, + R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, +// TODO(b/69425338): This uses too much memory on GPU. +#ifndef XLA_TEST_BACKEND_GPU + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +#endif + R1Spec{10, 2, 4, 2}, + R1Spec{10, 0, 10, 2}, + R1Spec{10, 0, 10, 3}, + R1Spec{10, 0, 10, 4}, + R1Spec{10, 0, 10, 5}, + R1Spec{10, 0, 10, 10}, + R1Spec{500, 200, 400, 7}, + R1Spec{4096, 1, 4095, 3}, + R1Spec{2047, 1024 - 24, 1024 + 160, 31}, + R1Spec{2047, 1, 2046, 3 * 128}, + R1Spec{4096, 1024 + 3, 4095, 500}, + R1Spec{8192, 0, 8192, 1024 * 3 + 400} + ), + SliceR1TestDataToString ); +// clang-format on struct R2Spec { int64 input_dim0; diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 0d56c9f48363d0569921d7c76050dcc66208931b..f9c62ec217d085e5c5a55f484c4bd712c6ccf05a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -24,6 +25,7 @@ namespace { template void PopulateWithRandomFloatingPointData(Literal* literal) { + // TODO(b/69179121): Generate data that is less self-similar. CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); std::minstd_rand0 engine; @@ -34,6 +36,19 @@ void PopulateWithRandomFloatingPointData(Literal* literal) { })); } +// The standard library does not have a case for bfloat16, unsurprisingly, so we +// handle that one specially. +template <> +void PopulateWithRandomFloatingPointData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), BF16); + std::minstd_rand0 engine; + std::uniform_real_distribution generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return static_cast(generator(engine)); + })); +} + template void PopulateWithRandomIntegralData(Literal* literal) { CHECK_EQ(literal->shape().element_type(), @@ -47,42 +62,131 @@ void PopulateWithRandomIntegralData(Literal* literal) { })); } -bool LooksLikeSum(const HloInstruction& instruction) { - return instruction.opcode() == HloOpcode::kAdd && - instruction.operand(0)->opcode() == HloOpcode::kParameter && - instruction.operand(1)->opcode() == HloOpcode::kParameter && - instruction.operand(0) != instruction.operand(1); +// Matches binary addition computations. +bool LooksLikeSum(const HloComputation& computation) { + const HloInstruction* const root = computation.root_instruction(); + return root->opcode() == HloOpcode::kAdd && + computation.num_parameters() == 2 && + root->operand(0)->opcode() == HloOpcode::kParameter && + root->operand(1)->opcode() == HloOpcode::kParameter && + root->operand(0) != root->operand(1); +} + +// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, +// which requires an init_value of 0 rather than a random value. +bool NeedsZeroInitValue(const HloUse& use) { + const HloInstruction* const instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + return ( + ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && + op_num == 1 && LooksLikeSum(*instruction->to_apply())) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && + LooksLikeSum(*instruction->scatter()))); } -// Given an instruction and operand number, replace the given operand with -// a Literal Constant Zero. Handle the case of a fusion instruction by -// replacing the fusion's parent's parameter with a Literal Constant Zero, -// unless the fusion's parent is itself a fusion. -Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction, - const int64 operand_number) { - CHECK_LT(operand_number, instruction->operand_count()); - if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) { - return Status::OK(); +// Generate random values that are constrained to the input_shape minus the +// output_shape so as not to produce wrapping slices, for instance. +std::unique_ptr MakeRandomNonwrappingSliceIndex( + const Shape& input_shape, const Shape& slice_shape) { + const int64 rank = ShapeUtil::Rank(input_shape); + std::vector start_indices(rank); + std::minstd_rand0 engine; + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(engine); } + return Literal::CreateR1(start_indices); +} - HloComputation* const computation = instruction->parent(); - std::unique_ptr zero = HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(instruction->shape().element_type()))); +// Use dataflow analysis on each parameter to see if there are uses that would +// be problematic when generating input data. Returns the list of instructions +// that correspond to their uses. +// +// Should be paired with the CreateLiteralForConstrainedUses() function below. +std::vector FindConstrainedUses( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + std::vector constrained_uses; + for (const auto& pair : dataflow.GetInstructionValueSet(¶m)) { + const HloValue& value = dataflow.GetUniqueValueAt(¶m, pair.first); + for (const HloUse& use : value.uses()) { + HloInstruction* instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + constrained_uses.push_back(instruction); + } else if (opcode == HloOpcode::kFusion) { + const HloInstruction* const to_analyze = + instruction->fused_parameter(op_num); + auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); + constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), + fused_uses.end()); + } else if (NeedsZeroInitValue(use)) { + constrained_uses.push_back(instruction); + } + } + } + return constrained_uses; +} - if (computation->IsFusionComputation()) { - HloInstruction* const fusion_instruction = computation->FusionInstruction(); - if (fusion_instruction->IsFused()) { - return Unimplemented( - "Unable to replace fused parameter of fusion instruction"); +// Given a parameter, generate a random Literal to use as input if there exist +// no constrained uses in the dataflow graph. If such constraints exist, +// generate a constrained literal (either bounded in the case of indices, or +// zero in the case of init_values for reductions). +StatusOr> CreateLiteralForConstrainedUses( + const tensorflow::gtl::ArraySlice constrained_uses, + const HloInstruction& param) { + HloInstruction* needs_index = nullptr; + HloInstruction* needs_zero = nullptr; + for (HloInstruction* use : constrained_uses) { + switch (use->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + TF_RET_CHECK(ShapeUtil::Equal(param.shape(), use->operand(0)->shape())); + if (needs_index != nullptr && + !ShapeUtil::Equal(needs_index->shape(), use->shape())) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } + needs_index = use; + break; + + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + needs_zero = use; + break; + + default: + return Unimplemented( + "Constrained operand generation not implemented for %s.", + use->ToString().c_str()); } - TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith( - instruction->operand(operand_number)->parameter_number(), - fusion_instruction->parent()->AddInstruction(std::move(zero)))); + } + if (needs_index != nullptr && needs_zero != nullptr) { + return Unimplemented( + "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " + "zero: %s\n", + needs_index->ToString().c_str(), needs_zero->ToString().c_str()); + } + if (needs_index != nullptr) { + return MakeRandomNonwrappingSliceIndex(param.shape(), needs_index->shape()); + } else if (needs_zero != nullptr) { + return Literal::CreateFromShape(param.shape()); } else { - TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith( - operand_number, computation->AddInstruction(std::move(zero)))); + return MakeFakeLiteral(param.shape()); } - return Status::OK(); +} + +// Given a module entry parameter, use the dataflow analysis to see if a +// special case literal must be created, or if we can generate fake data. +StatusOr> MakeConstrainedArgument( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + const auto constrained_uses = FindConstrainedUses(dataflow, param); + return CreateLiteralForConstrainedUses(constrained_uses, param); } } // namespace @@ -99,6 +203,9 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { } std::unique_ptr literal = Literal::CreateFromShape(shape); switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData(literal.get()); + break; case F32: PopulateWithRandomFloatingPointData(literal.get()); break; @@ -146,33 +253,17 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { } StatusOr>> MakeFakeArguments( - const HloModule& module) { - std::vector> arguments; - for (const ShapeLayout& shape_layout : - module.config().entry_computation_layout().parameter_layouts()) { - TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape())); - arguments.push_back(std::move(literal)); + HloModule* const module) { + TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); + const auto params = module->entry_computation()->parameter_instructions(); + std::vector> arguments(params.size()); + for (int i = 0; i < params.size(); ++i) { + TF_ASSIGN_OR_RETURN(arguments[i], + MakeConstrainedArgument(*dataflow, *params[i])); } return std::move(arguments); } -Status ReplaceInitsWithConstants(HloModule* const module) { - for (HloComputation* const computation : module->computations()) { - for (HloInstruction* const instruction : computation->instructions()) { - const HloOpcode opcode = instruction->opcode(); - if ((opcode == HloOpcode::kReduce || - opcode == HloOpcode::kReduceWindow) && - LooksLikeSum(*instruction->to_apply()->root_instruction())) { - TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1)); - } else if (opcode == HloOpcode::kSelectAndScatter && - LooksLikeSum(*instruction->scatter()->root_instruction())) { - TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2)); - } - } - } - return Status::OK(); -} - Status VerifyHloModule(const perftools::gputools::Platform& platform, HloModule* const module) { return HloVerifier( diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 9aca162a185e5b22888229555b7bce88769c79a6..0fb024ffb074f1c90b75022bc7f5a8b58b03c0c2 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -60,13 +60,11 @@ StatusOr> MakeFakeLiteral(const Shape& shape); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. +// +// Will handle special cases such as making sure that indices used for dynamic +// slices are bounded, reduces that call adds use 0 as an init value, etc. StatusOr>> MakeFakeArguments( - const HloModule& module); - -// Reductions using Adds, ReduceWindow, and SelectAndScatter, require their -// init_value to be replaced with the constant 0.0f when testing, otherwise we -// may generate a bad init_value when looking at the op in isolation. -Status ReplaceInitsWithConstants(HloModule* const module); + HloModule* const module); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index c30cd1b7b8e9be50d33fafb12d70e204e7321864..ed556fafb17cb2d243141783f822400d3c54b459 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -33,29 +33,27 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { - namespace { class TransferManagerTest : public LocalClientTestBase { protected: - TransferManagerTest() { - shape_size_fn_ = [this](const Shape& shape) { - return transfer_manager_->GetByteSizeRequirement(shape); - }; - } + TransferManagerTest() + : shape_size_fn_([this](const Shape& shape) { + return transfer_manager_->GetByteSizeRequirement(shape); + }) {} - ~TransferManagerTest() override {} + ~TransferManagerTest() override = default; std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { - return ScopedShapedBuffer::Allocate( - shape, GetOrCreateAllocator(local_client_->platform()), - /*device_ordinal=*/0, shape_size_fn_) - .ConsumeValueOrDie(); + return transfer_manager_ + ->AllocateScopedShapedBuffer( + shape, GetOrCreateAllocator(local_client_->platform()), + /*device_ordinal=*/0) + .ValueOrDie(); } + private: std::function shape_size_fn_; }; @@ -214,6 +212,39 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { LiteralTestUtil::ExpectEqual(*literal, *result); } -} // namespace +XLA_TEST_F(TransferManagerTest, TransferComplexValue) { + std::unique_ptr literal = Literal::CreateR1( + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} +XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR1( + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) + .get(), + Literal::CreateR1({1, 2, 3, 4, 5, 6}).get(), + Literal::CreateR0(complex64(0.3f, -0.4f)).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +} // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 4920f17a7ed21d587c15b8deac550d5e5bb566c9..65489cfff19c8fecbdead8a7e295bf9cca56038f 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -180,7 +180,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } -XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { +// TODO(b/68395210): GPU does not tolerate ambiguous top-level buffers. +XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) { ComputationBuilder b(client_, TestName()); ComputationDataHandle v1, v2; @@ -444,5 +445,61 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { ComputeAndCompareR1(&builder, expected, arguments, ErrorSpec(1e-5)); } +XLA_TEST_F(TupleTest, ComplexTuples) { + ComputationBuilder builder(client_, TestName()); + { + Shape c64r0 = ShapeUtil::MakeShape(C64, {}); + Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); + Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); + Shape arg0_shape = ShapeUtil::MakeTupleShape( + {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); + auto input0 = builder.Parameter(0, arg0_shape, "input0"); + auto t0 = builder.GetTupleElement(input0, 0); + auto t1 = builder.GetTupleElement(input0, 1); + auto t10 = builder.GetTupleElement(t1, 0); + auto t11 = builder.GetTupleElement(t1, 1); + auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); + auto input1 = builder.Parameter(1, c64r1, "input1"); + auto prod = builder.Mul(input1, sum, {1}); + builder.Tuple({builder.Tuple({prod, sum}), + builder.ConstantR0({123, 456})}); + } + + std::unique_ptr arg0 = + client_ + ->TransferToServer(*Literal::MakeTuple( + {Literal::CreateR0({1, 2}).get(), + Literal::MakeTuple( + {Literal::CreateR1({{10, 20}, {30, 40}}).get(), + Literal::CreateR2( + {{{100, 200}, {300, 400}}, + {{1000, 2000}, {3000, 4000}}, + {{10000, 20000}, {30000, 40000}}}) + .get()}) + .get()})) + .ConsumeValueOrDie(); + std::unique_ptr arg1 = + client_ + ->TransferToServer(*Literal::CreateR1({{1, 2}, {1, -2}})) + .ConsumeValueOrDie(); + auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, + {{1011, 2022}, {3031, 4042}}, + {{10011, 20022}, {30031, 40042}}}); + auto prod = Literal::CreateFromShape(sum->shape()); + ASSERT_TRUE(prod->Populate( + [&sum](tensorflow::gtl::ArraySlice indexes) { + return sum->Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? complex64(1, 2) + : complex64(1, -2)); + }) + .ok()); + auto expected = + Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), + Literal::CreateR0({123, 456}).get()}); + ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 49f673f5f0bf9b844ab4030383784208b4e2c58a..0b3430ee1ee515c2c98c64a947b7a7021c04f22b 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -357,8 +357,7 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { +TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -411,8 +410,7 @@ TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) { +TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -913,8 +911,7 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { } } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); @@ -950,8 +947,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { ErrorSpec(1e-6)); } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithBroadcast)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index ce936af6c3376387c1ed9fa48da23b8af537f6e5..97aacf6b39f83978e732060817cd93ede81ca782 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -34,9 +34,9 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", ], diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 6232967f5f04cbf316d985357ae84c28335531e2..2e329cc513dfa83070065a34b67b70ec2ca4b2e9 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -1,24 +1,26 @@ # HloModule string syntax -TODO: Support all subcomputations (for fusion, reduce, ...). - -TODO: Support all extra attributes, e.g. dimensions, strides. - ```yacc hlo_module : 'HloModule' name computations ; +/* If no computation is marked as ENTRY, the last computation will be the entry +computation of the module.*/ computations : computation | computation computations ; computation - : 'ENTRY' name param_list '->' shape instruction_list - | name param_list '->' shape instruction_list + : 'ENTRY' name param_list_to_shape instruction_list + | name param_list_to_shape instruction_list + | 'ENTRY' name instruction_list + | name instruction_list ; +/* If no instruction is marked as ROOT, the last instruction will be the root of +its computation. */ instruction_list : '{' instruction_list1 '}' ; @@ -41,6 +43,7 @@ operands1 ; operand : shape name + | name ; attributes @@ -60,6 +63,10 @@ attribute_value | '{' sub_attributes '}' ; +param_list_to_shape + : param_list '->' shape + ; + param_list : '(' param_list1 ')' ; @@ -84,6 +91,7 @@ tuple_elements name : identifier ':' | '%' identifier + | identifier ; identifier diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index 56744440db1b17aa1cc8823feb1bad279f8f4f75..6d1e4173d25a032970284fc7abbc3d2ec30b27cd 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -153,15 +152,15 @@ TokKind HloLexer::LexToken() { } } -// Lex a shape, name, keyword, opcode, attribute name, or the dim labels -// pattern. +// Lex a shape, name, keyword, attribute name, the dim labels pattern, and +// other identifiers. // // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: // keyword ::= HloModule, ENTRY, ... -// opcode ::= add, greater-than, ... // attribute_name ::= condition, body, dimensions, ... // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); @@ -220,20 +219,6 @@ TokKind HloLexer::LexIdentifier() { #undef KEYWORD - // See if this is an opcode. - auto opcode = StringToHloOpcode(identifier.ToString()); - if (opcode.ok()) { - opcode_val_ = opcode.ValueOrDie(); - return TokKind::kOpcode; - } - - // See if this is an fusion kind. - auto kind = xla::StringToFusionKind(identifier.ToString()); - if (kind.ok()) { - fusion_kind_val_ = kind.ValueOrDie(); - return TokKind::kFusionKind; - } - { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 dim_labels_pattern = { @@ -244,8 +229,9 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kDimLabels; } } - current_ptr_ = token_start_ + 1; - return TokKind::kError; + + str_val_ = identifier.ToString(); + return TokKind::kIdent; } // Lex names after a % character. @@ -271,7 +257,8 @@ TokKind HloLexer::LexPercent() { // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} // dxd_pattern ::= [0-9]+(x[0-9]+)+ -// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* +// pad_pattern ::= +// [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)* // int ::= [-]?[0-9]+ // negative inf ::= '-inf' TokKind HloLexer::LexNumberOrPattern() { @@ -289,7 +276,7 @@ TokKind HloLexer::LexNumberOrPattern() { R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"}; static LazyRE2 pad_pattern = { - R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"}; + R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"}; if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); @@ -326,18 +313,43 @@ TokKind HloLexer::LexNumberOrPattern() { return TokKind::kError; } -StringPiece HloLexer::GetCurrentLine() const { - const char* start = token_start_; - const char* end = current_ptr_; - if (!CanDereference(start) || !CanDereference(end)) { - return "LINE OUT OF RANGE"; +std::pair HloLexer::GetLineAndColumn(LocTy location) const { + unsigned line_no = 1; + const char* start = buf_.begin(); + const char* ptr = start; + if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) && + line_no_cache_.last_query <= location) { + ptr = line_no_cache_.last_query; + line_no = line_no_cache_.line_no_of_query; } - while (start > buf_.begin() && *start != '\n') { - start--; + for (; ptr != location; ptr++) { + if (*ptr == '\n') { + line_no++; + } } - while (end < buf_.end() && *end != '\n') { - end++; + + // Update the line number cache. + line_no_cache_.last_query = ptr; + line_no_cache_.line_no_of_query = line_no; + size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); + if (line_offset == StringPiece::npos) { + line_offset = 0; } + return {line_no, ptr - start - line_offset}; +} + +StringPiece HloLexer::GetLine(LocTy loc) const { + if (!CanDereference(loc)) { + return "LINE OUT OF RANGE"; + } + size_t line_start = + StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); + const char* start = line_start == StringPiece::npos + ? buf_.begin() + : buf_.begin() + line_start + 1; + size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); + const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + return StringPieceFromPointers(start, end); } @@ -428,14 +440,12 @@ string TokKindToString(TokKind kind) { return "kDxD"; case TokKind::kPad: return "kPad"; + case TokKind::kIdent: + return "kIdent"; case TokKind::kString: return "kString"; case TokKind::kShape: return "kShape"; - case TokKind::kOpcode: - return "kOpcode"; - case TokKind::kFusionKind: - return "kFusionKind"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 5c9d1bf3912584040dc5260cc6730247d439fd60..27880b9b8afbfa58abfedc3b2cecd5236b78a6d6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -18,9 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" @@ -48,6 +47,7 @@ class HloLexer { case TokKind::kDxD: case TokKind::kPad: case TokKind::kString: + case TokKind::kIdent: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -57,14 +57,6 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - HloOpcode GetOpcodeVal() const { - CHECK(GetKind() == TokKind::kOpcode); - return opcode_val_; - } - HloInstruction::FusionKind GetFusionKindVal() const { - CHECK(GetKind() == TokKind::kFusionKind); - return fusion_kind_val_; - } int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; @@ -74,8 +66,16 @@ class HloLexer { return decimal_val_; } - // Returns the line of text that is currently being lexed. - tensorflow::StringPiece GetCurrentLine() const; + typedef const char* LocTy; + + // Returns the location of the current token. + LocTy GetLoc() const { return token_start_; } + + // Returns the line and column of a location in the buffer. + std::pair GetLineAndColumn(LocTy location) const; + + // Returns the whole line given the location. + tensorflow::StringPiece GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -114,10 +114,15 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - HloOpcode opcode_val_; - HloInstruction::FusionKind fusion_kind_val_; int64 int64_val_; double decimal_val_; + + struct LineNoCacheTy { + const char* last_query; + unsigned line_no_of_query; + }; + // This caches the line number of the previous query. + mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; } // namespace tools diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 2112b3e710a4543d14f0e31243aef74dc6943b54..68fb9dd9ec8fa60b68906448ef55aa669c2506cb 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -40,6 +41,8 @@ const double kF16max = 65504; // Parser for the HloModule::ToString() format text. class HloParser { public: + using LocTy = HloLexer::LocTy; + explicit HloParser(StringPiece str, const HloModuleConfig& config) : lexer_(str), config_(config) {} @@ -56,7 +59,7 @@ class HloParser { // ParseXXX returns false if an error occurred. bool ParseHloModule(); bool ParseComputations(); - bool ParseComputation(); + bool ParseComputation(HloComputation** entry_computation); bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); @@ -104,6 +107,7 @@ class HloParser { kPaddingConfig, kMetadata, kFusionKind, + kDistribution, }; struct AttrConfig { @@ -167,6 +171,7 @@ class HloParser { bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); bool ParseName(string* result); bool ParseAttributeName(string* result); @@ -174,13 +179,21 @@ class HloParser { bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); bool ParseFusionKind(HloInstruction::FusionKind* result); + bool ParseRandomDistribution(RandomDistribution* result); bool ParseInt64(int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); + // Returns true if the current token is the beginning of a shape. + bool CanBeShape(); + // Returns true if the current token is the beginning of a + // param_list_to_shape. + bool CanBeParamListToShape(); + // Logs the current parsing line and the given message. Always returns false. bool TokenError(StringPiece msg); + bool Error(LocTy loc, StringPiece msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. @@ -191,10 +204,12 @@ class HloParser { // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. - bool AddInstruction(const string& name, HloInstruction* instruction); + bool AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc); // Adds the computation to the pool. Returns false and emits an error if the // computation already exists. - bool AddComputation(const string& name, HloComputation* computation); + bool AddComputation(const string& name, HloComputation* computation, + LocTy name_loc); // The map from the instruction name to the instruction. This does not own the // instructions. @@ -203,19 +218,30 @@ class HloParser { HloLexer lexer_; std::unique_ptr module_; + std::vector> computations_; const HloModuleConfig config_; std::vector error_; }; -bool HloParser::TokenError(StringPiece msg) { - const string error = - StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; token ", - TokKindToString(lexer_.GetKind()), "; ", msg); - VLOG(1) << "TokenError: " << error; - error_.push_back(error); +bool HloParser::Error(LocTy loc, StringPiece msg) { + auto line_col = lexer_.GetLineAndColumn(loc); + const unsigned line = line_col.first; + const unsigned col = line_col.second; + std::vector error_lines; + error_lines.push_back( + StrCat("was parsing ", line, ":", col, ": error: ", msg)); + error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); + + error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + VLOG(1) << "Error: " << error_.back(); return false; } +bool HloParser::TokenError(StringPiece msg) { + return Error(lexer_.GetLoc(), msg); +} + bool HloParser::Run() { lexer_.Lex(); return ParseHloModule(); @@ -241,27 +267,67 @@ bool HloParser::ParseHloModule() { // computations ::= (computation)+ bool HloParser::ParseComputations() { + HloComputation* entry_computation = nullptr; do { - if (!ParseComputation()) { + if (!ParseComputation(&entry_computation)) { return false; } } while (lexer_.GetKind() != TokKind::kEof); + + for (int i = 0; i < computations_.size(); i++) { + // If entry_computation is not nullptr, it means the computation it pointed + // to is marked with "ENTRY"; otherwise, no computation is marked with + // "ENTRY", and we use the last computation as the entry computation. We + // add the non-entry computations as embedded computations to the module. + if ((entry_computation != nullptr && + computations_[i].get() != entry_computation) || + (entry_computation == nullptr && i != computations_.size() - 1)) { + module_->AddEmbeddedComputation(std::move(computations_[i])); + continue; + } + auto computation = + module_->AddEntryComputation(std::move(computations_[i])); + // The parameters and result layouts were set to default layout. Here we + // set the layouts to what the hlo text says. + for (int p = 0; p < computation->num_parameters(); p++) { + const Shape& param_shape = computation->parameter_instruction(p)->shape(); + if (param_shape.has_layout()) { + module_->mutable_entry_computation_layout() + ->mutable_parameter_layout(p) + ->ResetLayout(param_shape.layout()); + } + } + const Shape& result_shape = computation->root_instruction()->shape(); + if (result_shape.has_layout()) { + module_->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(result_shape.layout()); + } + } + return true; } -// computation ::= ('ENTRY')? name param_list '->' shape instruction_list -bool HloParser::ParseComputation() { +// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list +bool HloParser::ParseComputation(HloComputation** entry_computation) { + LocTy maybe_entry_loc = lexer_.GetLoc(); const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); + string name; + LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name)) { return false; } auto builder = MakeUnique(name); + LocTy shape_loc = nullptr; Shape shape; + if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) { + return false; + } + string root_name; - if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || - !ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) { + if (!ParseInstructionList(builder.get(), &root_name)) { return false; } @@ -273,14 +339,37 @@ bool HloParser::ParseComputation() { LOG(FATAL) << "instruction " << root_name << " was marked as ROOT but the parser has not seen it before"; } + // Now root can be either an existing instruction or a nullptr. If it's a // nullptr, the implementation of Builder will set the last instruction as // root instruction. - HloComputation* computation = - is_entry_computation - ? module_->AddEntryComputation(builder->Build(root)) - : module_->AddEmbeddedComputation(builder->Build(root)); - return AddComputation(name, computation); + computations_.emplace_back(builder->Build(root)); + HloComputation* computation = computations_.back().get(); + + if (!root) { + root = computation->root_instruction(); + } else { + CHECK_EQ(root, computation->root_instruction()); + } + + // If param_list_to_shape was present, check compatibility. + if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + return Error( + shape_loc, + StrCat("Shape of computation ", name, ", ", + ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + root_name, ", ", ShapeUtil::HumanString(root->shape()))); + } + + if (is_entry_computation) { + if (*entry_computation != nullptr) { + return Error(maybe_entry_loc, "expects only one ENTRY"); + } + *entry_computation = computation; + } + + return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' @@ -307,13 +396,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, Shape shape; HloOpcode opcode; std::vector operands; + + LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); + + const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || !ParseToken(TokKind::kEqual, "expects '=' in instruction") || !ParseShape(&shape) || !ParseOpcode(&opcode)) { return false; } + if (is_root) { + if (!root_name->empty()) { + return Error(maybe_root_loc, "one computation should have only one ROOT"); + } *root_name = name; } @@ -434,13 +531,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateConvert(shape, operands[0])); break; } - case HloOpcode::kCrossReplicaSum: { + case HloOpcode::kBitcastConvert: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands[0])); + HloInstruction::CreateBitcastConvert(shape, operands[0])); + break; + } + case HloOpcode::kCrossReplicaSum: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCrossReplicaSum(shape, operands)); break; } case HloOpcode::kReshape: { @@ -549,13 +654,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kReduceWindow: { optional reduce_computation; optional window; - attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } + if (!window) { + window.emplace(); + } instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow( shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window, *reduce_computation)); @@ -564,13 +672,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kConvolution: { optional window; optional dnums; - attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/true, AttrTy::kConvolutionDimensionNumbers, &dnums}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } + if (!window) { + window.emplace(); + } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); break; @@ -644,11 +755,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional scatter; attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter}; optional window; - attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; if (!ParseOperands(&operands, /*expected_size=*/3) || !ParseAttributes(attrs)) { return false; } + if (!window) { + window.emplace(); + } instruction = builder->AddInstruction(HloInstruction::CreateSelectAndScatter( shape, /*operand=*/operands[0], *select, *window, @@ -798,15 +912,69 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands[0], config ? *config : "")); break; } - case HloOpcode::kConditional: - case HloOpcode::kCustomCall: - case HloOpcode::kReducePrecision: - case HloOpcode::kRng: + case HloOpcode::kRng: { + optional distribution; + attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, + &distribution}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateRng(shape, *distribution, operands)); + break; + } + case HloOpcode::kReducePrecision: { + optional exponent_bits; + optional mantissa_bits; + attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, + &exponent_bits}; + attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, + &mantissa_bits}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateReducePrecision( + shape, operands[0], static_cast(*exponent_bits), + static_cast(*mantissa_bits))); + break; + } + case HloOpcode::kConditional: { + optional true_computation; + optional false_computation; + attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &true_computation}; + attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &false_computation}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConditional( + shape, /*pred=*/operands[0], + /*true_computation_arg=*/operands[1], *true_computation, + /*false_computation_arg=*/operands[2], *false_computation)); + break; + } + case HloOpcode::kCustomCall: { + optional custom_call_target; + attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, + &custom_call_target}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target)); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } + instruction->set_name(name); + // Add common attrs (sharding, control predecessors) to the instruction, if // they were seen. if (sharding) { @@ -817,15 +985,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, for (auto* pre : *predecessors) { Status status = pre->AddControlDependencyTo(instruction); if (!status.ok()) { - return TokenError(StrCat("error adding control dependency for: ", name, - " status: ", status.ToString())); + return Error(name_loc, StrCat("error adding control dependency for: ", + name, " status: ", status.ToString())); } } } if (metadata) { instruction->set_metadata(*metadata); } - return AddInstruction(name, instruction); + return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) // ::= '{' (single_sharding | tuple_sharding) '}' @@ -871,6 +1039,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } + LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; std::vector devices; @@ -938,34 +1107,35 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, if (replicated) { if (!devices.empty()) { - return TokenError( - "replicated shardings should not have any devices assigned"); + return Error(loc, + "replicated shardings should not have any devices assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError( - "replicated shardings should not have any tile shape set"); + return Error(loc, + "replicated shardings should not have any tile shape set"); } sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { - return TokenError( - "maximal shardings should have exactly one device assigned"); + return Error(loc, + "maximal shardings should have exactly one device assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("maximal shardings should not have any tile shape set"); + return Error(loc, "maximal shardings should not have any tile shape set"); } sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); } else { if (devices.size() <= 1) { - return TokenError( - "non-maximal shardings must have more than one device assigned"); + return Error( + loc, "non-maximal shardings must have more than one device assigned"); } if (ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("non-maximal shardings should have a tile shape set"); + return Error(loc, "non-maximal shardings should have a tile shape set"); } if (tile_assignment_dimensions.empty()) { - return TokenError( + return Error( + loc, "non-maximal shardings must have a tile assignment list including " "dimensions"); } @@ -990,10 +1160,11 @@ bool HloParser::ParseInstructionNames( "expects '{' at the beginning of instruction name list")) { return false; } + LocTy loc = lexer_.GetLoc(); do { string name; if (!ParseName(&name)) { - return TokenError("expects a instruction name"); + return Error(loc, "expects a instruction name"); } HloInstruction* instr = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); @@ -1005,7 +1176,7 @@ bool HloParser::ParseInstructionNames( } while (EatIfPresent(TokKind::kComma)); return ParseToken(TokKind::kRbrace, - "expects '}' at the end of control instructions"); + "expects '}' at the end of instruction name list"); } bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, @@ -1040,6 +1211,8 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, switch (shape.element_type()) { case F16: return SetValueInLiteralHelper(value, linear_index, literal); + case BF16: + return SetValueInLiteralHelper(value, linear_index, literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1078,7 +1251,8 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, (std::numeric_limits::infinity() == value || -std::numeric_limits::infinity() == value))) { // Skip range checking for non-finite value. - } else if (literal->shape().element_type() == F16) { + } else if (literal->shape().element_type() == F16 || + literal->shape().element_type() == BF16) { if (value > kF16max || value < -kF16max) { return TokenError(StrCat( "value ", value, " is out of range for literal's primitive type ", @@ -1164,12 +1338,6 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // rank2345 ::= shape nested_array bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, const Shape& shape) { - const int64 size = ShapeUtil::ElementsIn(shape); - if (size == 0) { - *literal = Literal::CreateFromShape(shape); - return true; - } - const int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1270,20 +1438,22 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, } lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); int64 value; if (!ParseInt64(&value)) { - return TokenError(StrCat("expects integer for primitive type: ", + return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } if (!SetValueInLiteral(value, linear_index++, literal->get())) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); double value; if (!ParseDouble(&value)) { - return TokenError( - StrCat("expect floating point value for primitive type: ", - PrimitiveType_Name(shape.element_type()))); + return Error( + loc, StrCat("expect floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); } if (!SetValueInLiteral(value, linear_index++, literal->get())) { return false; @@ -1305,7 +1475,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, // operands1 // ::= /*empty*/ // ::= operand (, operand)* -// operand ::= shape name +// operand ::= (shape)? name bool HloParser::ParseOperands(std::vector* operands) { if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { @@ -1315,15 +1485,21 @@ bool HloParser::ParseOperands(std::vector* operands) { // empty } else { do { - Shape shape; + LocTy loc = lexer_.GetLoc(); string name; - if (!ParseShape(&shape) || !ParseName(&name)) { + if (CanBeShape()) { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + } + if (!ParseName(&name)) { return false; } HloInstruction* instruction = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); if (!instruction) { - return TokenError(StrCat("instruction does not exist: ", name)); + return Error(loc, StrCat("instruction does not exist: ", name)); } operands->push_back(instruction); } while (EatIfPresent(TokKind::kComma)); @@ -1333,11 +1509,12 @@ bool HloParser::ParseOperands(std::vector* operands) { bool HloParser::ParseOperands(std::vector* operands, const int expected_size) { + LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; } if (expected_size != operands->size()) { - return TokenError(StrCat("expects ", expected_size, " operands, but has ", + return Error(loc, StrCat("expects ", expected_size, " operands, but has ", operands->size(), " operands")); } return true; @@ -1346,6 +1523,7 @@ bool HloParser::ParseOperands(std::vector* operands, // sub_attributes ::= '{' (','? attribute)* '}' bool HloParser::ParseSubAttributes( const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { return false; } @@ -1364,7 +1542,7 @@ bool HloParser::ParseSubAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return TokenError(Printf("sub-attribute %s is expected but not seen", + return Error(loc, Printf("sub-attribute %s is expected but not seen", attr_it.first.c_str())); } } @@ -1374,6 +1552,7 @@ bool HloParser::ParseSubAttributes( // attributes ::= (',' attribute)* bool HloParser::ParseAttributes( const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); std::unordered_set seen_attrs; while (EatIfPresent(TokKind::kComma)) { if (!ParseAttributeHelper(attrs, &seen_attrs)) { @@ -1384,7 +1563,7 @@ bool HloParser::ParseAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return TokenError(Printf("attribute %s is expected but not seen", + return Error(loc, Printf("attribute %s is expected but not seen", attr_it.first.c_str())); } } @@ -1394,21 +1573,23 @@ bool HloParser::ParseAttributes( bool HloParser::ParseAttributeHelper( const std::unordered_map& attrs, std::unordered_set* seen_attrs) { + LocTy loc = lexer_.GetLoc(); string name; if (!ParseAttributeName(&name)) { - return TokenError("error parsing attributes"); + return Error(loc, "error parsing attributes"); } VLOG(1) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { - return TokenError(Printf("attribute %s already exists", name.c_str())); + return Error(loc, Printf("attribute %s already exists", name.c_str())); } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return TokenError(Printf("unexpected attribute %s", name.c_str())); + return Error(loc, Printf("unexpected attribute %s", name.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; bool success = [&] { + LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { int64 result; @@ -1424,7 +1605,7 @@ bool HloParser::ParseAttributeHelper( return false; } if (result != static_cast(result)) { - return TokenError("value out of range for int32"); + return Error(attr_loc, "value out of range for int32"); } static_cast*>(attr_out_ptr) ->emplace(static_cast(result)); @@ -1437,7 +1618,7 @@ bool HloParser::ParseAttributeHelper( } if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest()) { - return TokenError("value out of range for float"); + return Error(attr_loc, "value out of range for float"); } static_cast*>(attr_out_ptr) ->emplace(static_cast(result)); @@ -1536,22 +1717,32 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kDistribution: { + RandomDistribution result; + if (!ParseRandomDistribution(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { - return TokenError(Printf("error parsing attribute %s", name.c_str())); + return Error(loc, Printf("error parsing attribute %s", name.c_str())); } return true; } bool HloParser::ParseComputationName(HloComputation** value) { string name; + LocTy loc = lexer_.GetLoc(); if (!ParseName(&name)) { - return TokenError("expects computation name"); + return Error(loc, "expects computation name"); } *value = tensorflow::gtl::FindPtrOrNull(computation_pool_, name); if (*value == nullptr) { - return TokenError(StrCat("computation does not exist: ", name)); + return Error(loc, StrCat("computation does not exist: ", name)); } return true; } @@ -1560,6 +1751,7 @@ bool HloParser::ParseComputationName(HloComputation** value) { // The subattributes can appear in any order. 'size=' is required, others are // optional. bool HloParser::ParseWindow(Window* window) { + LocTy loc = lexer_.GetLoc(); if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -1569,10 +1761,12 @@ bool HloParser::ParseWindow(Window* window) { std::vector> pad; std::vector lhs_dilate; std::vector rhs_dilate; + std::vector rhs_reversal; while (lexer_.GetKind() != TokKind::kRbrace) { + LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { - return TokenError("expects sub-attributes in window"); + return Error(attr_loc, "expects sub-attributes in window"); } bool ok = [&] { if (field_name == "size") { @@ -1590,7 +1784,10 @@ bool HloParser::ParseWindow(Window* window) { if (field_name == "pad") { return ParseWindowPad(&pad); } - return TokenError(StrCat("unexpected attribute name: ", field_name)); + if (field_name == "rhs_reversal") { + return ParseDxD("rhs_reversal", &rhs_reversal); + } + return Error(loc, StrCat("unexpected attribute name: ", field_name)); }(); if (!ok) { return false; @@ -1598,20 +1795,20 @@ bool HloParser::ParseWindow(Window* window) { } if (size.empty()) { - return TokenError( - "sub-attribute 'size=' is required in the window attribute"); + return Error(loc, + "sub-attribute 'size=' is required in the window attribute"); } if (!stride.empty() && stride.size() != size.size()) { - return TokenError("expects 'stride=' has the same size as 'size='"); + return Error(loc, "expects 'stride=' has the same size as 'size='"); } if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { - return TokenError("expects 'lhs_dilate=' has the same size as 'size='"); + return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='"); } if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { - return TokenError("expects 'rhs_dilate=' has the same size as 'size='"); + return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='"); } if (!pad.empty() && pad.size() != size.size()) { - return TokenError("expects 'pad=' has the same size as 'size='"); + return Error(loc, "expects 'pad=' has the same size as 'size='"); } for (int i = 0; i < size.size(); i++) { @@ -1626,6 +1823,8 @@ bool HloParser::ParseWindow(Window* window) { lhs_dilate.empty() ? 1 : lhs_dilate[i]); window->mutable_dimensions(i)->set_window_dilation( rhs_dilate.empty() ? 1 : rhs_dilate[i]); + window->mutable_dimensions(i)->set_window_reversal( + rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } @@ -1673,7 +1872,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( StrCat("expects unique lhs dimension numbers, but sees ", lhs)); } for (int i = 0; i < rank - 2; i++) { - dnums->add_spatial_dimensions(-1); + dnums->add_input_spatial_dimensions(-1); } for (int i = 0; i < rank; i++) { char c = lhs[i]; @@ -1682,7 +1881,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } else if (c == 'f') { dnums->set_input_feature_dimension(i); } else if (c < '0' + rank && c >= '0') { - dnums->set_spatial_dimensions(c - '0', i); + dnums->set_input_spatial_dimensions(c - '0', i); } else { return TokenError( Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); @@ -1720,6 +1919,9 @@ bool HloParser::ParseConvolutionDimensionNumbers( return TokenError( StrCat("expects unique output dimension numbers, but sees ", out)); } + for (int i = 0; i < rank - 2; i++) { + dnums->add_output_spatial_dimensions(-1); + } for (int i = 0; i < rank; i++) { char c = out[i]; if (c == 'b') { @@ -1727,11 +1929,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } else if (c == 'f') { dnums->set_output_feature_dimension(i); } else if (c < '0' + rank && c >= '0') { - if (dnums->spatial_dimensions(c - '0') != i) { - return TokenError( - "output spatial dimensions should be the same as input spatial " - "dimensions"); - } + dnums->set_output_spatial_dimensions(c - '0', i); } else { return TokenError( Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); @@ -1772,20 +1970,19 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); } do { + LocTy loc = lexer_.GetLoc(); ranges.emplace_back(); if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon, &ranges.back())) { return false; } - } while (EatIfPresent(TokKind::kComma)); - - for (const auto& range : ranges) { + const auto& range = ranges.back(); if (range.size() != 2 && range.size() != 3) { - return TokenError(Printf( - "expects [start:limit:step] or [start:limit], but sees %ld elements.", - range.size())); + return Error(loc, Printf("expects [start:limit:step] or [start:limit], " + "but sees %ld elements.", + range.size())); } - } + } while (EatIfPresent(TokKind::kComma)); for (const auto& range : ranges) { result->starts.push_back(range[0]); @@ -1821,6 +2018,19 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +// param_list_to_shape ::= param_list '->' shape +bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { + return false; + } + *shape_loc = lexer_.GetLoc(); + return ParseShape(shape); +} + +bool HloParser::CanBeParamListToShape() { + return lexer_.GetKind() == TokKind::kLparen; +} + // param_list ::= '(' param_list1 ')' // param_list1 // ::= /*empty*/ @@ -1837,8 +2047,8 @@ bool HloParser::ParseParamList() { } else { do { Shape shape; - if (!ParseToken(TokKind::kName, "expects name in parameter") || - !ParseShape(&shape)) { + string name; + if (!ParseName(&name) || !ParseShape(&shape)) { return false; } } while (EatIfPresent(TokKind::kComma)); @@ -1877,9 +2087,17 @@ bool HloParser::ParseShape(Shape* result) { return true; } +bool HloParser::CanBeShape() { + // A non-tuple shape starts with a kShape token; a tuple shape starts with + // '('. + return lexer_.GetKind() == TokKind::kShape || + lexer_.GetKind() == TokKind::kLparen; +} + bool HloParser::ParseName(string* result) { VLOG(1) << "ParseName"; - if (lexer_.GetKind() != TokKind::kName) { + if (lexer_.GetKind() != TokKind::kIdent && + lexer_.GetKind() != TokKind::kName) { return TokenError("expects name"); } *result = lexer_.GetStrVal(); @@ -1907,15 +2125,16 @@ bool HloParser::ParseString(string* result) { } bool HloParser::ParseDxD(const string& name, std::vector* result) { + LocTy loc = lexer_.GetLoc(); if (!result->empty()) { - return TokenError( - Printf("sub-attribute '%s=' already exists", name.c_str())); + return Error(loc, + Printf("sub-attribute '%s=' already exists", name.c_str())); } // 1D if (lexer_.GetKind() == TokKind::kInt) { int64 number; if (!ParseInt64(&number)) { - return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str())); + return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } result->push_back(number); return true; @@ -1924,8 +2143,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); if (!SplitAndParseAsInts(str, 'x', result)) { - return TokenError( - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + return Error(loc, + Printf("expects sub-attribute '%s=ixj...'", name.c_str())); } lexer_.Lex(); return true; @@ -1934,8 +2153,9 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } bool HloParser::ParseWindowPad(std::vector>* pad) { + LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { - return TokenError("sub-attribute 'pad=' already exists"); + return Error(loc, "sub-attribute 'pad=' already exists"); } if (lexer_.GetKind() != TokKind::kPad) { return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); @@ -1946,8 +2166,8 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { - return TokenError( - "expects padding_low and padding_high separated by '_'"); + return Error(loc, + "expects padding_low and padding_high separated by '_'"); } pad->push_back(low_high); } @@ -1963,15 +2183,16 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { if (lexer_.GetKind() != TokKind::kPad) { return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); } + LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { - return TokenError( - "expects padding config pattern like 'low_high_interior' or " - "'low_high'"); + return Error(loc, + "expects padding config pattern like 'low_high_interior' or " + "'low_high'"); } auto* dim = padding->add_dimensions(); dim->set_edge_padding_low(padding_dim[0]); @@ -2013,20 +2234,51 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; - if (lexer_.GetKind() != TokKind::kOpcode) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects opcode"); } - *result = lexer_.GetOpcodeVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToHloOpcode(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects opcode but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(1) << "ParseFusionKind"; - if (lexer_.GetKind() != TokKind::kFusionKind) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects fusion kind"); } - *result = lexer_.GetFusionKindVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToFusionKind(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseRandomDistribution(RandomDistribution* result) { + VLOG(1) << "ParseRandomDistribution"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToRandomDistribution(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects random distribution but sees: %s, error: %s", + val.c_str(), status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } @@ -2092,20 +2344,20 @@ bool HloParser::EatIfPresent(TokKind kind) { return true; } -bool HloParser::AddInstruction(const string& name, - HloInstruction* instruction) { +bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc) { auto result = instruction_pool_.insert({name, instruction}); if (!result.second) { - return TokenError(StrCat("instruction already exists: ", name)); + return Error(name_loc, StrCat("instruction already exists: ", name)); } return true; } -bool HloParser::AddComputation(const string& name, - HloComputation* computation) { +bool HloParser::AddComputation(const string& name, HloComputation* computation, + LocTy name_loc) { auto result = computation_pool_.insert({name, computation}); if (!result.second) { - return TokenError(StrCat("computation already exists: ", name)); + return Error(name_loc, StrCat("computation already exists: ", name)); } return true; } @@ -2116,7 +2368,7 @@ StatusOr> Parse(StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { - return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); } return parser.ConsumeHloModule(); } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index cb02ef84a9295fb100c77f2951e6acf3cce896f1..e6f7ee7c08f4d17a8d8ac58ec4662756b7c7159f 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -46,7 +46,7 @@ std::vector CreateTestCases() { // ax + y { "AxpyParam", -R"(HloModule axpy_module: +R"(HloModule axpy_module ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { %alpha = f32[] parameter(0) @@ -62,7 +62,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { // pred constant { "ConstantPred", -R"(HloModule constant_pred_module: +R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} @@ -73,7 +73,7 @@ ENTRY %constant_pred () -> pred[] { // s32 constant { "ConstantS32", -R"(HloModule constant_s32_module: +R"(HloModule constant_s32_module ENTRY %constant_s32 () -> s32[] { ROOT %constant = s32[] constant(-42) @@ -84,18 +84,40 @@ ENTRY %constant_s32 () -> s32[] { // f32 constant, but the value is not a decimal { "ConstantF32", -R"(HloModule ConstantF32_module: +R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { ROOT %constant = f32[] constant(42) } +)" +}, +// f32 constant, rank 1 empty array. +{ +"ConstantF32R1Empty", +R"(HloModule ConstantF32Empty_module + +ENTRY %ConstantF32Empty.v4 () -> f32[0] { + ROOT %constant = f32[0]{0} constant({}) +} + +)" +}, +// f32 constant, rank 4 empty array. +{ +"ConstantF32R4Empty", +R"(HloModule ConstantF32R4Empty_module + +ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { + ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) +} + )" }, // constant 4D { "Constant4D", -R"(HloModule Small_3x2x1x1_module: +R"(HloModule Small_3x2x1x1_module ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) @@ -106,7 +128,7 @@ ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { // non-finite constants: nan, inf, -inf { "ConstantNonFinite", -R"(HloModule IsFiniteR1F32s_module: +R"(HloModule IsFiniteR1F32s_module ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf}) @@ -118,18 +140,29 @@ ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { // constant f16 { "ConstantF16", -R"(HloModule ConstantF16_module: +R"(HloModule ConstantF16_module ENTRY %ConstantF16.v4 () -> f16[] { ROOT %constant = f16[] constant(500) } +)" +}, +// bf16 +{ +"BF16", +R"(HloModule BF16 + +ENTRY %BF16.v4 () -> bf16[] { + ROOT %constant = bf16[] constant(500) +} + )" }, // constant + constant { "AddConstants", -R"(HloModule add_constants_module: +R"(HloModule add_constants_module ENTRY %add_constants () -> f32[] { %constant = f32[] constant(3.14) @@ -141,7 +174,7 @@ ENTRY %add_constants () -> f32[] { // tuple constant { "TupleConstant", -R"(HloModule TupleConstant_module: +R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) @@ -152,7 +185,7 @@ ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { // v1 > v2 ? v1 : v2 { "SelectR1F32", -R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: +R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} @@ -166,7 +199,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3 // empty tuple { "EmptyTupleCreate", -R"(HloModule EmptyTupleCreate_module: +R"(HloModule EmptyTupleCreate_module ENTRY %EmptyTupleCreate.v1 () -> () { ROOT %tuple = () tuple() @@ -177,7 +210,7 @@ ENTRY %EmptyTupleCreate.v1 () -> () { // tuple { "TupleCreate", -R"(HloModule TupleCreate_module: +R"(HloModule TupleCreate_module ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { %v1 = f32[] parameter(0) @@ -190,7 +223,7 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f }, { "ShardedTupleCreate", -R"(HloModule ShardedTupleCreate_module: +R"(HloModule ShardedTupleCreate_module ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { %v1 = f32[] parameter(0) @@ -205,7 +238,7 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 // while (result < 5) { result = result + 1; } { "WhileWithScalarS32Result", -R"(HloModule WhileWithScalarS32Result_module: +R"(HloModule WhileWithScalarS32Result_module %body.v3 (prev.1: s32[]) -> s32[] { %constant = s32[] constant(1) @@ -229,7 +262,7 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { // send and recv { "SendRecv", -R"(HloModule TwoSendRecvBothWayRecvFist_module: +R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} @@ -244,7 +277,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { // get-tuple-element { "GetTupleElement", -R"(HloModule GetTupleElement_module: +R"(HloModule GetTupleElement_module ENTRY %GetTupleElement.v4 () -> s32[2,3] { %constant = f32[3]{0} constant({1, 2, 3}) @@ -258,7 +291,7 @@ ENTRY %GetTupleElement.v4 () -> s32[2,3] { // call { "Call", -R"(HloModule CallR0F32IdentityScalar_module: +R"(HloModule CallR0F32IdentityScalar_module %Identity.v1 (x: f32[]) -> f32[] { ROOT %x = f32[] parameter(0) @@ -274,7 +307,7 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { // reduce window { "ReduceWindow", -R"(HloModule R4UnitWindow_module: +R"(HloModule R4UnitWindow_module %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { %lhs = f32[] parameter(0) @@ -288,12 +321,31 @@ ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] { ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3 } +)" +}, +// reduce window on scalar +{ +"ReduceWindowScalar", +R"(HloModule reduce_window_scalar + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %R4UnitWindowScalar () -> f32[] { + %constant = f32[] constant(42) + %constant.1 = f32[] constant(1) + ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3 +} + )" }, // convolution { "Convolution", -R"(HloModule Convolve1D1Window_0_module: +R"(HloModule Convolve1D1Window_0_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -307,12 +359,25 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 // convolution rank 2 { "ConvolutionR2", -R"(HloModule ConvolveR2_module: +R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) %filter = f32[1,1]{1,0} parameter(1) - ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), window={size=1}, dim_labels=bf_io->bf + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf +} + +)" +}, +// convolution backward +{ +"ConvolutionBackward", +R"(HloModule ConvolveBackward_module + +ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { + %input = f32[128,7,7,512]{0,3,2,1} parameter(0) + %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f } )" @@ -320,7 +385,7 @@ ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { // reverse(constant) { "Reverse4D", -R"(HloModule Reverse4DFloatArrayOnDim01_module: +R"(HloModule Reverse4DFloatArrayOnDim01_module ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) @@ -332,7 +397,7 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { // concat { "Concat", -R"(HloModule Concat2x3With2x5_module: +R"(HloModule Concat2x3With2x5_module ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) @@ -342,48 +407,36 @@ ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { )" }, -// map +// select and scatter { -"Map", -R"(HloModule MapBinaryAdder_module: +"SelectAndScatter", +R"(HloModule R4F32OverlapSmall_module -%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { +%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) -} - -ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] { - %param0 = f32[4]{0} parameter(0) - %param1 = f32[4]{0} parameter(1) - ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3 + ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) } -)" -}, -// reduce -{ -"Reduce", -R"(HloModule ReduceR3ToR2_module: - -%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { + %lhs.1 = f32[] parameter(0) + %rhs.1 = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) } -ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] { - %input = f32[8,16,256]{2,1,0} parameter(0) - %constant = f32[] constant(0) - ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3 +ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { + %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant.2 = f32[] constant(0) + ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 } )" }, -// select and scatter +// select and scatter on scalar { -"SelectAndScatter", -R"(HloModule R4F32OverlapSmall_module: +"SelectAndScatterScalar", +R"(HloModule select_and_scatter_scalar %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) @@ -397,11 +450,11 @@ R"(HloModule R4F32OverlapSmall_module: ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) } -ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { - %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) - %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) - %constant.2 = f32[] constant(0) - ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 +ENTRY %SelectAndScatterScalar () -> f32[] { + %constant = f32[] constant(42) + %constant.1 = f32[] constant(1) + %constant.2 = f32[] constant(2) + ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3 } )" @@ -409,7 +462,7 @@ ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { // slice { "Slice", -R"(HloModule slice_module: +R"(HloModule slice_module ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) @@ -421,7 +474,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { // slice, no stride { "SliceNoStride", -R"(HloModule Slice3x3x3_To_1x3x3_F32_module: +R"(HloModule Slice3x3x3_To_1x3x3_F32_module ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) @@ -433,7 +486,7 @@ ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { // slice R0 { "SliceR0", -R"(HloModule SliceR0_module: +R"(HloModule SliceR0_module ENTRY %SliceR0.v2 () -> s32[] { %constant = s32[] constant(1) @@ -445,7 +498,7 @@ ENTRY %SliceR0.v2 () -> s32[] { // transpose { "Transpose", -R"(HloModule Transpose_module: +R"(HloModule Transpose_module ENTRY %Transpose.v2 () -> s32[1,2,3] { %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) @@ -457,7 +510,7 @@ ENTRY %Transpose.v2 () -> s32[1,2,3] { // Dynamic slice { "DynamicSlice", -R"(HloModule DynamicSlice_module: +R"(HloModule DynamicSlice_module ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] { %original_parameter = s32[2,2,258]{2,1,0} parameter(0) @@ -472,7 +525,7 @@ ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) - // Dynamic update slice { "DynamicUpdateSlice", -R"(HloModule DynamicUpdateSlice_module: +R"(HloModule DynamicUpdateSlice_module ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { %input = s32[1,1,25,1]{3,2,1,0} parameter(0) @@ -486,7 +539,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ // batch norm training { "BatchNormTraining", -R"(HloModule BasicTraining_module: +R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) @@ -500,7 +553,7 @@ ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { // batch norm inference { "BatchNormInference", -R"(HloModule BatchNormInference_module: +R"(HloModule BatchNormInference_module ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] { %input = f32[2,2,2,2]{3,2,1,0} parameter(0) @@ -516,7 +569,7 @@ ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2] // batch norm grad { "BatchNormGrad", -R"(HloModule BatchNormGrad_module: +R"(HloModule BatchNormGrad_module ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) { %input = f32[2,2,2,2]{3,2,1,0} parameter(0) @@ -532,7 +585,7 @@ ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], varia // pad { "Pad", -R"(HloModule Pad1DS3Array_module: +R"(HloModule Pad1DS3Array_module ENTRY %Pad1DS3Array.v3 () -> f32[8] { %constant = f32[3]{0} constant({1, 2, 3}) @@ -545,7 +598,7 @@ ENTRY %Pad1DS3Array.v3 () -> f32[8] { // pad has interior { "PadHasInterior", -R"(HloModule PadHasInterior_module: +R"(HloModule PadHasInterior_module ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { %input = f32[1,25,7,7]{3,2,1,0} parameter(0) @@ -553,12 +606,25 @@ ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0 } +)" +}, +// Negative padding +{ +"PadHasNegativePadding", +R"(HloModule PadHasNegativePadding_module + +ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] { + %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0) + %constant = f32[] constant(-5.123) + ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3 +} + )" }, // fusion { "Fusion", -R"(HloModule fusion_module: +R"(HloModule fusion_module %fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] { %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) @@ -573,22 +639,140 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } +)" +} + }); + // clang-format on +} + +std::vector CreateShortTestCases() { + // clang-format off + return std::vector({ +// map +{ +"Map", +R"(HloModule MapBinaryAdder_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY MapBinaryAdder.v3 { + param0 = f32[4]{0} parameter(0) + param1 = f32[4]{0} parameter(1) + ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 +} + +)" +}, +// reduce +{ +"Reduce", +R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} + )" }, // infeed/outfeed { "InfeedOutfeed", -R"(HloModule outfeed_module: +R"(HloModule outfeed_module + +ENTRY InfeedToOutfeed { + infeed = (u32[3]{0}, pred[]) infeed() + outfeed = () outfeed(infeed) + ROOT infeed.1 = (u32[3]{0}, pred[]) infeed() + outfeed.1 = () outfeed(infeed.1) +} + +)" +}, +// Rng +{ +"Rng", +R"(HloModule rng_module -ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { - %infeed = (u32[3]{0}, pred[]) infeed() - %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed) - ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed() - %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1) +ENTRY Rng { + constant = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform } )" +}, +// Reduce precision +{ +"ReducePrevison", +R"(HloModule reduce_precision + +ENTRY ReducePrecision { + constant = f32[1]{0} constant({3.14159}) + ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10 } + +)" +}, +// Conditional +{ +"Conditional", +R"(HloModule conditional + +Negate { + x = f32[] parameter(0) + ROOT negate = f32[] negate(x) +} + +Identity { + y = f32[] parameter(0) + ROOT copy = f32[] copy(y) +} + +ENTRY Parameters1.v4 { + constant = pred[] constant(true) + constant.1 = f32[] constant(56) + constant.2 = f32[] constant(12) + ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity +} + +)" +}, +// CustomCall +{ +"CustomCall", +R"(HloModule custom_call + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar" +} + +)" +}, +// Variables with non-default names +{ +"NonDefaultNames", +R"(HloModule add_constants_module + +ENTRY add_constants { + foo = f32[] constant(3.14) + ROOT bar = f32[] add(foo, foo) +} + +)" +}, }); // clang-format on } @@ -607,18 +791,35 @@ class HloParserTest : public ::testing::Test, void ExpectEqual() { const string& original = GetParam().module_string; auto result = Parse(original); - TF_EXPECT_OK(result.status()); + TF_ASSERT_OK(result.status()); + EXPECT_EQ(original, result.ValueOrDie()->ToString( + HloPrintOptions().set_print_large_constants(true))); + } +}; + +class HloParserShortTest : public HloParserTest { + protected: + void ExpectEqualShort() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); EXPECT_EQ(original, - result.ValueOrDie()->ToString(/*include_large_constants=*/true)); + result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); } }; TEST_P(HloParserTest, Run) { ExpectEqual(); } +TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } + INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, ::testing::ValuesIn(CreateTestCases()), TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); + TEST_F(HloParserTest, Empty) { const string original = ""; auto result = Parse(original); @@ -682,7 +883,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } TEST_F(HloParserTest, MoreConstants) { - const string original = R"(HloModule SelectScalarS32True_module: + const string original = R"(HloModule SelectScalarS32True_module ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) @@ -699,7 +900,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { - const string original = R"(HloModule some_2_module: + const string original = R"(HloModule some_2_module ENTRY %some_2 () -> f32[2] { ROOT %constant = f32[2]{0} constant({1,{2}}) @@ -713,7 +914,7 @@ ENTRY %some_2 () -> f32[2] { } TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { - const string original = R"(HloModule some_2x3_module: + const string original = R"(HloModule some_2x3_module ENTRY %some_2x3 () -> f32[2,3] { ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) @@ -727,7 +928,7 @@ ENTRY %some_2x3 () -> f32[2,3] { } TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { - const string original = R"(HloModule some_2x3x2_module: + const string original = R"(HloModule some_2x3x2_module ENTRY %some_2x3x2 () -> f32[2,3,2] { ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) @@ -742,7 +943,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { TEST_F(HloParserTest, ConstantF16Overflow) { const string original = - R"(HloModule ConstantF16Overflow_module: + R"(HloModule ConstantF16Overflow_module ENTRY %ConstantF16Overflow.v4 () -> f16[] { ROOT %constant = f16[] constant(-65505) @@ -756,7 +957,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } TEST_F(HloParserTest, ConstantWithExp) { - const string original = R"(HloModule ConstantWithExp_module: + const string original = R"(HloModule ConstantWithExp_module ENTRY %ConstantWithExp.v4 () -> f32[] { %constant.1 = f32[] constant(3e+2) @@ -771,7 +972,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } TEST_F(HloParserTest, AttibutesAnyOrder) { - const string original = R"(HloModule any_order_module: + const string original = R"(HloModule any_order_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -785,7 +986,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } TEST_F(HloParserTest, InvalidDimLabels) { - string prefix = R"(HloModule invalid_dim_labels_module: + string prefix = R"(HloModule invalid_dim_labels_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -806,16 +1007,10 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 .status() .error_message(), "must have the same rank"); - - ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=0bf_io0->b0f", suffix)) - .status() - .error_message(), - "output spatial dimensions should be the same as input " - "spatial dimensions"); } TEST_F(HloParserTest, UnexpectedAttribute) { - const string original = R"(HloModule unexpected_attr_module: + const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -831,7 +1026,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, MissingAttribute) { - const string original = R"(HloModule missing_attr_module: + const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -847,7 +1042,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, PredecessorUndefined) { - const string original = R"(HloModule pre_not_found_module: + const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -863,7 +1058,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, SliceAllowOmitStride1) { - const string original = R"(HloModule slice_module: + const string original = R"(HloModule slice_module ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) @@ -875,7 +1070,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { - const string original = R"(HloModule window_pad_module: + const string original = R"(HloModule window_pad_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -890,7 +1085,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } TEST_F(HloParserTest, CommaBetweenSubAttributes) { - const string original = R"(HloModule test_comma_module: + const string original = R"(HloModule test_comma_module ENTRY %test_comma.v4 () -> f32[] { ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} @@ -900,6 +1095,95 @@ ENTRY %test_comma.v4 () -> f32[] { TF_EXPECT_OK(Parse(original).status()); } +TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { + const string original = R"(HloModule custom_call: + +ENTRY %CustomCall () -> f32[1] { + %constant = f32[1]{0} constant({12345}) + ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "Shape of computation CustomCall, f32[1], is not compatible " + "with that of its root instruction foo, f32[1,2,3]"); +} + +TEST_F(HloParserTest, EntryComputationWithLayout) { + const string original = R"(HloModule layout: +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + auto program_layout = module.ValueOrDie()->entry_computation_layout(); + ASSERT_EQ(program_layout.parameter_count(), 1); + auto param_layout = program_layout.parameter_layout(0).layout(); + auto result_layout = program_layout.result_layout().layout(); + EXPECT_TRUE( + LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout)) + << "actual layout of parameter(0) is " + << LayoutUtil::HumanString(param_layout); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout)) + << "actual layout of result is " + << LayoutUtil::HumanString(result_layout); +} + +TEST_F(HloParserTest, NoEntry) { + const string original = R"(HloModule no_entry: +c1 { + const1 = f32[1]{0} constant({12345}) +} +c2 { + const2 = f32[1]{0} constant({67890}) +})"; + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); +} + +TEST_F(HloParserTest, NoRoot) { + const string original = R"(HloModule no_root: +ENTRY consts { + first = f32[1]{0} constant({12345}) + last = f32[1]{0} constant({67890}) +})"; + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + EXPECT_EQ( + module.ValueOrDie()->entry_computation()->root_instruction()->name(), + "last"); +} + +TEST_F(HloParserTest, MultipleEntries) { + const string original = R"(HloModule multiple_entries: +ENTRY c1 { + const1 = f32[1]{0} constant({12345}) +} +ENTRY c2 { + const2 = f32[1]{0} constant({67890}) +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "expects only one ENTRY"); +} + +TEST_F(HloParserTest, MultipleRoots) { + const string original = R"(HloModule multiple_roots: +ENTRY consts { + ROOT const1 = f32[1]{0} constant({12345}) + ROOT const2 = f32[1]{0} constant({12345}) +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "one computation should have only one ROOT"); +} + } // namespace } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index 07e48804d053f31bdff6678f09ee2c1e3b731e0f..7928bee5c2097f353b182095a555c334d7b69c95 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" + namespace xla { namespace tools { @@ -60,10 +63,9 @@ enum class TokKind { kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} kDxD, // [0-9]+(x[0-9]+)+ kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers kString, // "abcd\"\n" kShape, // f32[2,3]{1,0} - kOpcode, // add - kFusionKind, // kLoop, kOutput, ... kInt, // 42 kDecimal, // 4.2 }; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 503e7d456e1f462b753610e8a08a47db7a714ed6..a7dc5862057047f7c56faeb211cc0b13992caec7 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -59,18 +59,26 @@ namespace xla { namespace tools { namespace { +// Command-line opts to this tool. See main() for descriptions of these +// fields. +struct Options { + string fake_infeed_shape; + bool use_fake_data = false; + bool print_result = true; + int num_runs = 1; +}; + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; // otherwise, no infeed is performed. StatusOr> ReplayComputation( - const SessionModule& module, tensorflow::StringPiece fake_infeed_shape, - bool use_fake_data, Client* client) { + const SessionModule& module, Client* client, const Options& opts) { TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); std::vector> arguments; - if (use_fake_data) { + if (opts.use_fake_data) { arguments = MakeFakeArgumentsOrDie(computation, client); } else { // use recorded data if available for (const auto& proto : module.arguments()) { @@ -85,12 +93,12 @@ StatusOr> ReplayComputation( // concurrent infeed occur via the fake_infeed_shape. tensorflow::gtl::optional pool; - if (!fake_infeed_shape.empty()) { + if (!opts.fake_infeed_shape.empty()) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([fake_infeed_shape, client]() { + pool->Schedule([opts, client]() { StatusOr shape_status = - ShapeUtil::ParseShapeString(fake_infeed_shape); + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); TF_CHECK_OK(shape_status.status()); Shape shape = std::move(shape_status).ValueOrDie(); StatusOr> data_status = MakeFakeLiteral(shape); @@ -107,11 +115,32 @@ StatusOr> ReplayComputation( for (auto& argument : arguments) { execute_arguments.push_back(argument.get()); } - return client->ExecuteAndTransfer(computation, execute_arguments); + + // Run the computation num_runs times, and return the result from the last + // execution. + std::unique_ptr result; + for (int i = 0; i < opts.num_runs; ++i) { + ExecutionProfile profile; + if (opts.print_result) { + TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( + computation, execute_arguments, + /*execution_options=*/nullptr, &profile)); + } else { + // If we're not printing the result, execute the computation but don't + // bother retrieving the result. This can be a significant speedup. + TF_RETURN_IF_ERROR(client + ->Execute(computation, execute_arguments, + /*execution_options=*/nullptr, &profile) + .status()); + } + LOG(INFO) << "Execution took " + << static_cast(profile.compute_time_ns()) / 1e9 << "s"; + } + + return std::move(result); } -int RealMain(tensorflow::gtl::ArraySlice args, - tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) { +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { Client* client = ClientLibrary::LocalClientOrDie(); tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; @@ -119,21 +148,24 @@ int RealMain(tensorflow::gtl::ArraySlice args, SessionModule module; TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); StatusOr> result_status = - ReplayComputation(module, fake_infeed_shape, use_fake_data, client); + ReplayComputation(module, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); exit_status = EXIT_FAILURE; continue; } + std::unique_ptr result = result_status.ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (module.has_result()) { - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), - Literal(module.result()).ToString().c_str()); + if (result != nullptr) { + fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), + ShapeUtil::HumanString(result->shape()).c_str(), + result->ToString().c_str()); + if (module.has_result()) { + fprintf(stdout, "was %s:%s\n", + ShapeUtil::HumanString(module.result().shape()).c_str(), + Literal(module.result()).ToString().c_str()); + } } } return exit_status; @@ -144,13 +176,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, } // namespace xla int main(int argc, char** argv) { - // Flags - xla::string fake_infeed_shape; - bool use_fake_data = false; + xla::tools::Options opts; const std::vector flag_list = { - tensorflow::Flag("use_fake_data", &use_fake_data, + tensorflow::Flag("use_fake_data", &opts.use_fake_data, "Replay computation using fake data"), - tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape, + tensorflow::Flag("print_result", &opts.print_result, + "Print the result of the computation to stdout"), + tensorflow::Flag("num_runs", &opts.num_runs, + "Number of times to run each computation"), + tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); @@ -162,5 +196,5 @@ int main(int argc, char** argv) { tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - return xla::tools::RealMain(args, fake_infeed_shape, use_fake_data); + return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 2624ef0252fd9482a600fe3aec07f7f328a86d69..fe5d29a6b655a89d559eb1214c2b8dd54d34094c 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -42,15 +42,15 @@ Status WithLogBacktrace(const Status& status) { } // namespace -ScopedLoggingTimer::ScopedLoggingTimer(const string& label, int32 vlog_level) - : label(label), vlog_level(vlog_level) { - if (VLOG_IS_ON(vlog_level)) { +ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled) + : enabled(enabled), label(label) { + if (enabled) { start_micros = tensorflow::Env::Default()->NowMicros(); } } ScopedLoggingTimer::~ScopedLoggingTimer() { - if (VLOG_IS_ON(vlog_level)) { + if (enabled) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); double secs = (end_micros - start_micros) / 1000000.0; @@ -191,9 +191,9 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice p) { - for (int64 i = 0; i < p.size(); ++i) { - if (p[i] != i) { +bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { + for (int64 i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) { return false; } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index f58f57b44396c90a3820835a3d0ecc792aaa7cd0..b722095d1f38bf8a984c3ce9092a65f8e0baa911 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -50,13 +50,43 @@ using DimensionVector = tensorflow::gtl::InlinedVector; // RAII timer that logs with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it // spits out the human-readable duration form. +// +// By default, the timing traces are only printed at VLOG(1) and above: +// +// XLA_SCOPED_LOGGING_TIMER("fooing bar"); // nop if !VLOG_IS_ON(1). +// +// but you can control this via: +// +// XLA_SCOPED_LOGGING_TIMER_LEVEL("fooing bar", 2); // nop if !VLOG_IS_ON(2) +// +#define XLA_SCOPED_LOGGING_TIMER(label) \ + XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__) +#define XLA_SCOPED_LOGGING_TIMER_LEVEL(label, level) \ + XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, __COUNTER__) + +// Helper for implementing macros above. Do not use directly. +// +// Forces the evaluation of "counter", which we expect is equal to __COUNTER__. +#define XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, counter) \ + XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) + +// Helper for macros above. Don't use directly. +#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) \ + ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \ + label, VLOG_IS_ON(level)) + +// RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL +// macros above. Recommended usage is via the macros so you don't have to give +// the timer a name or worry about calling VLOG_IS_ON yourself. struct ScopedLoggingTimer { - explicit ScopedLoggingTimer(const string& label, int32 vlog_level = 1); + // The timer does nothing if enabled is false. This lets you pass in your + // file's VLOG_IS_ON value. + ScopedLoggingTimer(const string& label, bool enabled); ~ScopedLoggingTimer(); - uint64 start_micros; + bool enabled; string label; - int32 vlog_level; + uint64 start_micros; }; // Given a vector, returns a MutableArraySlice that points at its diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 6f7f1479b90377ea3c2019508acb6db311c5a1ba..293f0781a203d092a7996d5548de1dbf5bf32e4c 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -44,6 +44,9 @@ namespace window_util { if (dim.window_dilation() != 1) { StrAppend(&str, ",window_dilation=", dim.window_dilation()); } + if (dim.window_reversal()) { + StrAppend(&str, ",window_reversal"); + } StrAppend(&str, ")"); return str; } @@ -85,6 +88,11 @@ string ToString(const Window& window) { return StrCat(dim.window_dilation()); }); } + if (HasWindowReversal(window)) { + add_field(" rhs_reversal", [](const WindowDimension& dim) { + return StrCat(dim.window_reversal() ? 1 : 0); + }); + } return str; } @@ -138,6 +146,15 @@ bool HasWindowDilation(const Window& window) { return false; } +bool HasWindowReversal(const Window& window) { + for (const auto& dim : window.dimensions()) { + if (dim.window_reversal()) { + return true; + } + } + return false; +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 235cb2d59d451a25dc4f824ab488f8cef6b03bfb..125900dac0c5ab478b834c315b4a438c9238ef6d 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -39,6 +39,8 @@ bool HasBaseDilation(const Window& window); bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); +bool HasWindowReversal(const Window& window); + // Returns the new bound after dilation. // // If a window with the given bound in some dimension is dilated with the given diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index eac8f2ff07e4a885affdc0f7b1563d3a2cb606d7..95045d5e28b96c8e9b31fccd62a24d5c83523092 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -114,6 +114,14 @@ message PaddingConfig { repeated PaddingConfigDimension dimensions = 1; } +// A format specifies the method used by a layout to store an array in memory. +enum Format { + INVALID_FORMAT = 0; + // The default layout, with exactly one storage location per element (ignoring + // padding). + DENSE = 1; +} + // A layout describes how the array is placed in (1D) memory space. This // includes the minor-to-major ordering of dimensions within a shape, as well as // any padding present in those dimensions. @@ -124,19 +132,23 @@ message PaddingConfig { // // See the XLA documentation for more information on shapes and layouts. message Layout { + // The method used to store the data in memory. The format determines which of + // the other fields are used by the layout. + Format format = 4; + // Sequence of dimension numbers, from minor (fastest varying index) to major // (slowest varying index). This field is required. repeated int64 minor_to_major = 1; - // The width to which the layout of each dimension is padded up - // to. If present, the size of the padded_dimensions must equal the - // rank of the shape. The padding appears at the end of a dimension, - // not at the beginning. This kind of padding, unlike padding in - // e.g. convolution, is not part of the shape. + // The width to which the layout of each dimension is padded up to. If + // present, the size of the padded_dimensions must equal the rank of the + // shape. The padding appears at the end of a dimension, not at the + // beginning. This kind of padding, unlike padding in e.g. convolution, is not + // part of the shape. This field must be unset unless the format is DENSE. repeated int64 padded_dimensions = 2; - // Describes the values in the padding specified by - // padded_dimensions. + // Describes the values in the padding specified by padded_dimensions. This + // field must be unset unless the format is DENSE. PaddingValue padding_value = 3; // Important: if any field is added, be sure to modify ShapeUtil::Equal() @@ -357,6 +369,10 @@ message WindowDimension { // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly // placed between each base area element. See documentation for convolution. int64 base_dilation = 6; + + // Window reversal means that this dimension was logically reversed before the + // operation. + bool window_reversal = 7; } // Describes the windowing in an operation such as convolution. @@ -413,15 +429,9 @@ message ConvolutionDimensionNumbers { // The number of the dimension that represents features in the input. int64 input_feature_dimension = 8; - // The number of the dimension that represents batch in the output. - int64 output_batch_dimension = 9; - - // The number of the dimension that represents features in the output. - int64 output_feature_dimension = 10; - // The dimension numbers for the spatial dimensions that the window - // moves through in the input (lhs) and output. - repeated int64 spatial_dimensions = 5; + // moves through in the input. + repeated int64 input_spatial_dimensions = 11; // The number of the dimension that represents input features in the // convolutional kernel (rhs). @@ -435,12 +445,24 @@ message ConvolutionDimensionNumbers { // moves through in the kernel (rhs). window.strides(0) is the // stride in the kernel_spatial_dimensions(0) dimension. repeated int64 kernel_spatial_dimensions = 6; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; + + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the output. + repeated int64 output_spatial_dimensions = 12; + + // Next = 13 }; message ConvolveRequest { ComputationDataHandle lhs = 2; ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kenel. + Window window = 4; // Describes the filter/kernel. ConvolutionDimensionNumbers dimension_numbers = 5; } @@ -488,6 +510,23 @@ message CustomCallRequest { Shape shape = 4; } +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. + repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +}; + +message DotRequest { + ComputationDataHandle lhs = 2; + ComputationDataHandle rhs = 3; + DotDimensionNumbers dimension_numbers = 4; +} + message MapRequest { repeated ComputationDataHandle operands = 2; ComputationHandle to_apply = 3; @@ -641,6 +680,14 @@ message ConcatenateRequest { int64 dimension = 3; } +message ConditionalRequest { + ComputationDataHandle predicate = 2; + ComputationDataHandle true_operand = 3; + ComputationHandle true_computation = 4; + ComputationDataHandle false_operand = 5; + ComputationHandle false_computation = 6; +} + message WhileRequest { ComputationHandle condition = 2; ComputationHandle body = 3; @@ -722,9 +769,6 @@ enum BinaryOperation { BINOP_LT = 9; BINOP_NE = 10; - // Dot product, matrix multiply. - BINOP_DOT = 12; - // Element-wise maximum. BINOP_MAX = 14; @@ -875,6 +919,7 @@ message OpRequest { ConvolveRequest convolve_request = 8; CrossReplicaSumRequest cross_replica_sum_request = 9; CustomCallRequest custom_call_request = 10; + DotRequest dot_request = 43; DynamicSliceRequest dynamic_slice_request = 11; DynamicUpdateSliceRequest dynamic_update_slice_request = 12; GetTupleElementRequest get_tuple_element_request = 13; @@ -903,7 +948,9 @@ message OpRequest { BatchNormGradRequest batch_norm_grad_request = 37; BatchNormInferenceRequest batch_norm_inference_request = 38; FftRequest fft_request = 41; - // Next: 42 + ConvertRequest bitcast_convert_request = 42; + ConditionalRequest conditional_request = 44; + // Next: 45 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index b7ade951150412e0ad3f72c235f0677e68fce66e..6e2320bd0d6376cfddb60f8069e141a88bc93563 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -9,7 +9,12 @@ load("//third_party/mpi:mpi.bzl", "if_mpi") py_library( name = "contrib_py", - srcs = glob(["**/*.py"]), + srcs = glob( + ["**/*.py"], + exclude = [ + "**/*_test.py", + ], + ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ @@ -48,6 +53,7 @@ py_library( "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/contrib/legacy_seq2seq:seq2seq_py", + "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", @@ -64,6 +70,7 @@ py_library( "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 1eda1abfcf779ece7af3dbf2554c2a0a8c2611e9..08247c6b38a4df663ad28a6b4d3c41a1da41a020 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -55,6 +55,7 @@ from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt +from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor from tensorflow.contrib import quantization from tensorflow.contrib import quantize diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index f49e5857fe5255c2459793cb1389052a2ff5f88f..c7c128bf14f03d3769ef08e83da61f6d2f91fbd2 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -15,9 +15,9 @@ For prebuilt libraries, see the page for a recent build. The TensorFlow Inference Interface is also available as a -[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and -can be included quite simply in your android project with a couple of lines in -the project's `build.gradle` file: +[JCenter package](https://bintray.com/google/tensorflow/tensorflow) +(see the tensorflow-android directory) and can be included quite simply in your +android project with a couple of lines in the project's `build.gradle` file: ``` allprojects { diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index 25ada5ba27aa167e4aaf4cebd6517e3b80aa1058..a115d1610e2334a6626f29674f3dd195e3a3c648 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -34,10 +34,12 @@ add_library(lib_tf STATIC IMPORTED ) set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION ${PREBUILT_DIR}/lib/libtensorflow-core.a) # Change to compile flags should be replicated into bazel build file +# TODO: Consider options other than -O2 for binary size. +# e.g. -Os for gcc, and -Oz for clang. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ -std=c++11 -fno-rtti -fno-exceptions \ -O2 -Wno-narrowing -fomit-frame-pointer \ - -mfpu=neon -mfloat-abi=softfp -fPIE \ + -mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \ -ftemplate-depth=900 \ -DGOOGLE_PROTOBUF_NO_RTTI \ -DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER") diff --git a/tensorflow/contrib/android/cmake/README.md b/tensorflow/contrib/android/cmake/README.md index 6f19b657fe72064bd7b005b568540cd52a5e19e8..934b58c7242fc06064ee3c06bc8f4c2740bd24ef 100644 --- a/tensorflow/contrib/android/cmake/README.md +++ b/tensorflow/contrib/android/cmake/README.md @@ -14,7 +14,7 @@ Add TensorFlow-Android-Inference as a dependency of your Android application ``` include ':TensorFlow-Android-Inference' -findProject(":TensorFlow-Android-Inference").projectDir = +findProject(":TensorFlow-Android-Inference").projectDir = new File("${/path/to/tensorflow_repo}/contrib/android/cmake") ``` diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 1f423a7a5bf6a115dc627ddd6f5e98c074282585..dc5b9fb88742d78d0f40207b589e29451a6358dd 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -160,7 +160,7 @@ public class TensorFlowInferenceInterface { throw new RuntimeException("Failed to load model from the input stream", e); } } - + /* * Construct a TensorFlowInferenceInterface with provided Graph * @@ -168,7 +168,7 @@ public class TensorFlowInferenceInterface { */ public TensorFlowInferenceInterface(Graph g) { prepareNativeRuntime(); - + // modelName is redundant here, here is for // avoiding error in initialization as modelName is marked final. this.modelName = ""; @@ -290,7 +290,7 @@ public class TensorFlowInferenceInterface { */ public void feed(String inputName, boolean[] src, long... dims) { byte[] b = new byte[src.length]; - + for (int i = 0; i < src.length; i++) { b[i] = src[i] ? (byte) 1 : (byte) 0; } diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index a111cfecb366fe245150cc71d2c43662d0d69090..ea8ac2c680e62ee03a45716aa1e0870d44495f1e 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -82,7 +82,10 @@ cc_library( tf_cc_test( name = "adaptive_shared_batch_scheduler_test", srcs = ["adaptive_shared_batch_scheduler_test.cc"], - tags = ["manual"], # b/69013768 + tags = [ + "local", + "manual", + ], deps = [ ":adaptive_shared_batch_scheduler", "//tensorflow/contrib/batching/test_util:fake_clock_env", diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index 6ed177e001758ad8c566c7965e1ec10ae5235fc8..a2cb146b8d69b6cc0eda8912a9c840ac4e0c7030 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#include #include #include #include +#include #include #include @@ -42,19 +44,36 @@ template class ASBSQueue; } // namespace internal +// EXPERIMENTAL: API MAY BE SUBJECTED TO SUDDEN CHANGES. +// // Shared batch scheduler designed to minimize latency. The scheduler keeps // track of a number of queues (one per model or model version) which are // continuously enqueuing requests. The scheduler groups the requests into // batches which it periodically sends off for processing (see // shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler // prioritizes batches by age (i.e. the batch's oldest request) irrespective of -// queue. The scheduler will process the oldest batch at an adjustable rate, -// regardless of batch size. The user can provide feedback to help set this rate -// to achieve some goal (i.e. minimize overall latency, limit cpu usage, etc). +// queue or batch size. +// +// The scheduling decision currently exists in two flavors, controlled by the +// option use_in_flight_batches_implementation. It is expected that setting this +// option to true will give universally better results; after a period of +// testing to confirm, the old implementation will be removed. // -// The rate (or rather, the corresponding period) is adjusted each time a batch -// is processed, using an exponentially weighted moving average to smooth -// potentially noisy feedback: +// If use_in_flight_batches_implementation is set to true, the scheduler +// limits the number of batches which can be processed concurrently. If a new +// batch is created, and the number of in flight batches is below the limit, +// the next (i.e. oldest) batch is immediately scheduled. Similarly, when a +// batch finishes processing, the limit is rechecked, and another batch may be +// scheduled. To avoid the need to carefully tune the limit for workload, +// model type, platform, etc, it is dynamically adjusted in order to provide the +// lowest latency. +// +// If use_in_flight_batches_implementation is set to false, the scheduler will +// process the oldest batch at an adjustable rate, regardless of batch size. +// The user can provide feedback to help set this rate to achieve some goal +// (i.e. minimize overall latency, limit cpu usage, etc). The rate (or rather, +// the corresponding period) is adjusted each time a batch is processed, using +// an exponentially weighted moving average to smooth noisy feedback: // ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N // period *= (1 + K * emwa_feedback) // @@ -82,6 +101,20 @@ class AdaptiveSharedBatchScheduler int64 num_batch_threads = port::NumSchedulableCPUs(); // The environment to use (typically only overridden by test code). Env* env = Env::Default(); + // Which implementation to use (described in class comments above). + bool use_in_flight_batches_implementation = false; + // Initial limit for number of batches being concurrently processed. + // Non-integer values correspond to probabilistic limits - i.e. a value of + // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time. + double initial_in_flight_batches_limit = 3; + // Number of batches between adjustments of in_flight_batches_limit. Larger + // numbers will give less noisy latency measurements, but will be less + // responsive to changes in workload. + int64 batches_to_average_over = 1000; + + // TODO(kte): remove the rate based implementation and corresponding options + // below once testing confirms the superiority of the in flight batches + // implementation. // Initial batch scheduling period in microseconds. Will be altered for // non-zero rate_feedback. double initial_scheduling_period_micros = 500; @@ -122,6 +155,11 @@ class AdaptiveSharedBatchScheduler BatchProcessor process_batch_callback, std::unique_ptr>* queue); + double in_flight_batches_limit() { + mutex_lock l(mu_); + return in_flight_batches_limit_; + } + private: // access to AddBatch, RemoveQueue, GetEnv. friend class internal::ASBSQueue; @@ -129,10 +167,20 @@ class AdaptiveSharedBatchScheduler explicit AdaptiveSharedBatchScheduler(const Options& options); // Batch scheduling function which runs every scheduling_period_ microseconds. + // Only used when options_.use_in_flight_batches_implementation == false. void ProcessOneBatch(); + // Tracks processing latency and adjusts in_flight_batches_limit to minimize. + // Only used when options_.use_in_flight_batches_implementation == true. + void CallbackWrapper(const internal::ASBSBatch* batch, + BatchProcessor callback); + + // Schedules batch if in_flight_batches_limit_ is not met. + // Only used when options_.use_in_flight_batches_implementation == true. + void MaybeScheduleNextBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Notifies scheduler of non-empty batch which is eligible for processing. - void AddBatch(internal::ASBSBatch*); + void AddBatch(const internal::ASBSBatch* batch); // Removes queue from scheduler. void RemoveQueue(const internal::ASBSQueue* queue); @@ -149,7 +197,8 @@ class AdaptiveSharedBatchScheduler // Collection of batches added by AddBatch, ordered by age. Owned by scheduler // until they are released for processing. std::priority_queue*, - std::vector*>, BatchCompare> + std::vector*>, + BatchCompare> batches_ GUARDED_BY(mu_); // Unowned queues and callbacks added by AddQueue. @@ -160,19 +209,56 @@ class AdaptiveSharedBatchScheduler // Responsible for running ProcessOneBatch. PeriodicFunction was used in order // to check for deletion so that the thread can be shut down. + // Only used when options_.use_in_flight_batches_implementation == false. std::unique_ptr scheduling_thread_; // Responsible for running the batch processing callbacks. std::unique_ptr batch_thread_pool_; // Time interval in microseconds between successive ProcessOneBatch calls. + // Only used when options_.use_in_flight_batches_implementation == false. double scheduling_period_; // Exponentially weighted moving average of // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch // call. + // Only used when options_.use_in_flight_batches_implementation == false. double ewma_feedback_ = 0; + // Limit on number of batches which can be concurrently processed. + // Non-integer values correspond to probabilistic limits - i.e. a value of 3.2 + // results in an actual cap of 3 80% of the time, and 4 20% of the time. + // Only used when options_.use_in_flight_batches_implementation == true. + double in_flight_batches_limit_ GUARDED_BY(mu_); + + // Number of batches currently being processed. + // Only used when options_.use_in_flight_batches_implementation == true. + int64 in_flight_batches_ GUARDED_BY(mu_) = 0; + + // RNG engine and distribution. + // Only used when options_.use_in_flight_batches_implementation == true. + std::default_random_engine rand_engine_; + std::uniform_real_distribution rand_double_; + + // Fields controlling the dynamic adjustment of in_flight_batches_limit_. + // Only used when options_.use_in_flight_batches_implementation == true. + // Number of batches since the last in_flight_batches_limit_ adjustment. + int64 batch_count_ GUARDED_BY(mu_) = 0; + // Sum of processing latency for batches counted by batch_count_. + int64 batch_latency_sum_ GUARDED_BY(mu_) = 0; + // Average batch latency for previous value of in_flight_batches_limit_. + double last_avg_latency_ms_ GUARDED_BY(mu_) = 0; + // Did last_avg_latency_ms_ decrease from the previous last_avg_latency_ms_? + bool last_latency_decreased_ GUARDED_BY(mu_) = false; + // Current direction (+-) to adjust in_flight_batches_limit_ + int step_direction_ GUARDED_BY(mu_) = 1; + // Max adjustment size (as a fraction of in_flight_batches_limit_). + constexpr static double kMaxStepSizeMultiplier = 0.125; // 1/8; + // Min adjustment size (as a fraction of in_flight_batches_limit_). + constexpr static double kMinStepSizeMultiplier = 0.0078125; // 1/128 + // Current adjustment size (as a fraction of in_flight_batches_limit_). + double step_size_multiplier_ GUARDED_BY(mu_) = kMaxStepSizeMultiplier; + TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler); }; @@ -208,6 +294,8 @@ class ASBSQueue : public BatchScheduler { // place any more tasks in this batch. void ReleaseBatch(const ASBSBatch* batch); + size_t max_task_size() const override { return options_.max_batch_size; } + private: std::shared_ptr> scheduler_; const QueueOptions options_; @@ -241,6 +329,12 @@ class ASBSBatch : public Batch { // ---------------- AdaptiveSharedBatchScheduler ---------------- +template +constexpr double AdaptiveSharedBatchScheduler::kMaxStepSizeMultiplier; + +template +constexpr double AdaptiveSharedBatchScheduler::kMinStepSizeMultiplier; + template Status AdaptiveSharedBatchScheduler::Create( const Options& options, @@ -275,6 +369,25 @@ Status AdaptiveSharedBatchScheduler::Create( "feedback_smoothing_batches must be positive; was ", options.feedback_smoothing_batches); } + if (options.initial_in_flight_batches_limit > options.num_batch_threads) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit (", + options.initial_in_flight_batches_limit, + ") should not be larger than num_batch_threads (", + options.num_batch_threads, ")"); + } + if (options.initial_in_flight_batches_limit < 1) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit should be " + "greater than or equal to 1; was ", + options.initial_in_flight_batches_limit); + } + if (options.batches_to_average_over < 1) { + return errors::InvalidArgument( + "batches_to_average_over should be " + "greater than or equal to 1; was ", + options.batches_to_average_over); + } scheduler->reset(new AdaptiveSharedBatchScheduler(options)); return Status::OK(); } @@ -283,14 +396,20 @@ template AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( const Options& options) : options_(options), - scheduling_period_(options.initial_scheduling_period_micros) { + scheduling_period_(options.initial_scheduling_period_micros), + in_flight_batches_limit_(options.initial_in_flight_batches_limit), + rand_double_(0.0, 1.0) { + std::random_device device; + rand_engine_.seed(device()); PeriodicFunction::Options opts; opts.thread_name_prefix = "scheduling_thread"; opts.env = GetEnv(); - scheduling_thread_.reset( - new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); batch_thread_pool_.reset(new thread::ThreadPool( GetEnv(), options.thread_pool_name, options.num_batch_threads)); + if (!options.use_in_flight_batches_implementation) { + scheduling_thread_.reset( + new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); + } } template @@ -316,9 +435,12 @@ Status AdaptiveSharedBatchScheduler::AddQueue( template void AdaptiveSharedBatchScheduler::AddBatch( - internal::ASBSBatch* batch) { + const internal::ASBSBatch* batch) { mutex_lock l(mu_); batches_.push(batch); + if (options_.use_in_flight_batches_implementation) { + MaybeScheduleNextBatch(); + } } template @@ -328,10 +450,78 @@ void AdaptiveSharedBatchScheduler::RemoveQueue( queues_and_callbacks_.erase(queue); } +template +void AdaptiveSharedBatchScheduler::MaybeScheduleNextBatch() { + if (batches_.empty() || in_flight_batches_ >= in_flight_batches_limit_) + return; + // Non-integer limit handled probabilistially. + if (in_flight_batches_limit_ - in_flight_batches_ < 1 && + rand_double_(rand_engine_) > + (in_flight_batches_limit_ - in_flight_batches_)) + return; + const internal::ASBSBatch* batch = batches_.top(); + batches_.pop(); + // Queue may destroy itself after ReleaseBatch is called. + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule( + std::bind(&AdaptiveSharedBatchScheduler::CallbackWrapper, this, + batch, queues_and_callbacks_[batch->queue()])); + in_flight_batches_++; +} + +template +void AdaptiveSharedBatchScheduler::CallbackWrapper( + const internal::ASBSBatch* batch, + AdaptiveSharedBatchScheduler::BatchProcessor callback) { + int64 start_time = batch->creation_time_micros(); + callback(std::unique_ptr>( + const_cast*>(batch))); + int64 end_time = GetEnv()->NowMicros(); + mutex_lock l(mu_); + in_flight_batches_--; + batch_count_++; + batch_latency_sum_ += end_time - start_time; + // Occasionally adjust in_flight_batches_limit_ to minimize average latency. + // Although the optimal value may depend on the workload, the latency should + // be a simple convex function of in_flight_batches_limit_, allowing us to + // locate the global minimum relatively quickly. + if (batch_count_ == options_.batches_to_average_over) { + double current_avg_latency_ms = (batch_latency_sum_ / 1000.) / batch_count_; + bool current_latency_decreased = + current_avg_latency_ms < last_avg_latency_ms_; + if (current_latency_decreased) { + // If latency improvement was because we're moving in the correct + // direction, increase step_size so that we can get to the minimum faster. + // If latency improvement was due to backtracking from a previous failure, + // decrease step_size in order to refine our location. + step_size_multiplier_ *= (last_latency_decreased_ ? 2 : 0.5); + step_size_multiplier_ = + std::min(step_size_multiplier_, kMaxStepSizeMultiplier); + step_size_multiplier_ = + std::max(step_size_multiplier_, kMinStepSizeMultiplier); + } else { + // Return (nearly) to previous position and confirm that latency is better + // there before decreasing step size. + step_direction_ = -step_direction_; + } + in_flight_batches_limit_ += + step_direction_ * in_flight_batches_limit_ * step_size_multiplier_; + in_flight_batches_limit_ = + std::min(in_flight_batches_limit_, + static_cast(options_.num_batch_threads)); + in_flight_batches_limit_ = std::max(in_flight_batches_limit_, 1.0); + last_avg_latency_ms_ = current_avg_latency_ms; + last_latency_decreased_ = current_latency_decreased; + batch_count_ = 0; + batch_latency_sum_ = 0; + } + MaybeScheduleNextBatch(); +} + template void AdaptiveSharedBatchScheduler::ProcessOneBatch() { static const double kFeedbackMultiplier = .001; - internal::ASBSBatch* batch = nullptr; + const internal::ASBSBatch* batch = nullptr; BatchProcessor callback; const int64 start_time_micros = GetEnv()->NowMicros(); { @@ -355,7 +545,8 @@ void AdaptiveSharedBatchScheduler::ProcessOneBatch() { // Queue may destroy itself after ReleaseBatch is called. batch->queue()->ReleaseBatch(batch); batch_thread_pool_->Schedule([callback, batch] { - callback(std::unique_ptr>(batch)); + callback(std::unique_ptr>( + const_cast*>(batch))); }); } const int64 sleep_time = @@ -425,6 +616,7 @@ Status ASBSQueue::Schedule(std::unique_ptr* task) { current_batch_->AddTask(std::move(*task)); num_enqueued_tasks_++; } + // AddBatch must be called outside of lock, since it may call ReleaseBatch. if (new_batch != nullptr) scheduler_->AddBatch(new_batch); return Status::OK(); } diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc index a07cd6d834fa28904bf7748b16972cca217503c1..18f1e554525a306ffe07460a889411ed4755b89f 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc @@ -141,6 +141,16 @@ TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) { options = Scheduler::Options(); options.feedback_smoothing_batches = 0; EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.initial_in_flight_batches_limit = 0.5; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.num_batch_threads = 5; + options.initial_in_flight_batches_limit = 8; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.batches_to_average_over = -5; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); } TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { @@ -186,6 +196,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { queue_options.max_enqueued_batches = 2; TF_ASSERT_OK( scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); + EXPECT_EQ(10, queue_0->max_task_size()); queue_options.max_batch_size = 0; // Queue must have max_batch_size > 0. EXPECT_FALSE( @@ -433,6 +444,107 @@ TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) { } stop_teardown.Notify(); } + +TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesImplementation) { + AdaptiveSharedBatchScheduler::Options options; + options.use_in_flight_batches_implementation = true; + options.initial_in_flight_batches_limit = 2; + options.batches_to_average_over = 1000; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + mutex mu; + int processed_batches = 0; + Notification finish_processing; + auto queue_callback = [&mu, &processed_batches, &finish_processing]( + std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + mu.lock(); + int batch_num = ++processed_batches; + mu.unlock(); + if (batch_num == 2) { + // Give third batch a chance to process if it's going to. + Env::Default()->SleepForMicroseconds(1000); + finish_processing.Notify(); + } + if (batch_num == 3) { + ASSERT_TRUE(finish_processing.HasBeenNotified()); + } + finish_processing.WaitForNotification(); + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Enqueue 3 batches. + for (int i = 0; i < 3; i++) { + TF_ASSERT_OK(ScheduleTask(100, queue.get())); + } +} + +TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesLimitTuning) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + AdaptiveSharedBatchScheduler::Options options; + options.env = &env; + options.use_in_flight_batches_implementation = true; + options.initial_in_flight_batches_limit = 2; + options.batches_to_average_over = 1; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + auto queue_callback = [&env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + switch (batch->size()) { + case 0: + env.AdvanceByMicroseconds(10); + break; + case 1: + env.AdvanceByMicroseconds(15); + break; + case 2: + env.AdvanceByMicroseconds(10); + break; + case 3: + env.AdvanceByMicroseconds(11); + break; + } + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + TF_ASSERT_OK(ScheduleTask(0, queue.get())); + double in_flight_batches_limit = 2; + while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { + } + // Initial direction will be negative. + EXPECT_LT(scheduler->in_flight_batches_limit(), in_flight_batches_limit); + in_flight_batches_limit = scheduler->in_flight_batches_limit(); + TF_ASSERT_OK(ScheduleTask(1, queue.get())); + while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { + } + // Latency increased -> change direction. + EXPECT_GT(scheduler->in_flight_batches_limit(), in_flight_batches_limit); + in_flight_batches_limit = scheduler->in_flight_batches_limit(); + TF_ASSERT_OK(ScheduleTask(2, queue.get())); + while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { + } + // Latency decreased -> keep going in same direction. + EXPECT_GT(scheduler->in_flight_batches_limit(), in_flight_batches_limit); + in_flight_batches_limit = scheduler->in_flight_batches_limit(); + TF_ASSERT_OK(ScheduleTask(3, queue.get())); + while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { + } + // Latency increased -> change direction. + EXPECT_LT(scheduler->in_flight_batches_limit(), in_flight_batches_limit); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} } // namespace anonymous } // namespace serving } // namespace tensorflow diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h index 9d3805fbaf39978159dd2f4a754e6d41a07acf6a..91065db2499dffd2687a53bd6304d9b7593f7b3a 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler.h +++ b/tensorflow/contrib/batching/basic_batch_scheduler.h @@ -192,6 +192,10 @@ class BasicBatchScheduler : public BatchScheduler { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { + return shared_scheduler_queue_->max_task_size(); + } + private: explicit BasicBatchScheduler( std::unique_ptr> shared_scheduler_queue); diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc index e020301795c7dadee2815c0e0d727e53e5fb9e6e..187823151cf840dcf8058677fcf74d1beffc3bc2 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc @@ -73,6 +73,7 @@ TEST(BasicBatchSchedulerTest, Basic) { std::unique_ptr> scheduler; TF_ASSERT_OK( BasicBatchScheduler::Create(options, callback, &scheduler)); + EXPECT_EQ(10, scheduler->max_task_size()); EXPECT_EQ(0, scheduler->NumEnqueuedTasks()); EXPECT_EQ(3 * 10, scheduler->SchedulingCapacity()); TF_ASSERT_OK(ScheduleTask(3, scheduler.get())); diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index a5072f439abad3c5db79a514a7f2baff0b021b39..e18cf6c35059e4d720768e3b2c02b03727a6bac4 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -178,6 +178,10 @@ class BatchScheduler { // This method is useful for monitoring, or for guaranteeing a future slot in // the schedule (but being mindful about the caveats listed above). virtual size_t SchedulingCapacity() const = 0; + + // Returns the maximum allowed size of tasks submitted to the scheduler. (This + // is typically equal to a configured maximum batch size.) + virtual size_t max_task_size() const = 0; }; ////////// diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h index 41a3f99137ade2552432fee62ddce17d064148a4..86c45bdc2e66e30fbde15f6cafe481cf969c14d0 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/shared_batch_scheduler.h @@ -63,7 +63,7 @@ namespace serving { // instead of N independent ones, with their sharing deliberately coordinated. // // SharedBatchScheduler does not implement the BatchScheduler API; rather, it -// presents an abstraction of "queues", where each queue coresponds to one type +// presents an abstraction of "queues", where each queue corresponds to one type // of task. Tasks submitted to a given queue are placed in their own batches, // and cannot be mixed with other tasks. Queues can be added and deleted // dynamically, to accommodate e.g. versions of a model being brought up and @@ -248,6 +248,9 @@ class Queue { // BatchScheduler::SchedulingCapacity(). size_t SchedulingCapacity() const; + // Returns the maximum allowed size of tasks submitted to the queue. + size_t max_task_size() const { return options_.max_batch_size; } + // Called by a thread that is ready to process a batch, to request one from // this queue. Either returns a batch that is ready to be processed, or // nullptr if the queue declines to schedule a batch at this time. If it @@ -338,6 +341,8 @@ class QueueHandle : public BatchScheduler { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { return queue_->max_task_size(); } + private: // The scheduler that owns 'queue_'. std::shared_ptr> scheduler_; diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc index 3e924ae5f13519b4fe9a3f4b510773ca2bddaf23..3ac79a8fdc47389816db8ca09f27846d1c4623c2 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc @@ -429,6 +429,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) { queue_options.max_enqueued_batches = max_enqueued_batches; std::unique_ptr> queue; TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue)); + EXPECT_EQ(2, queue->max_task_size()); EXPECT_EQ(0, queue->NumEnqueuedTasks()); EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity()); diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index a262d4aecdbb69dfcd8b88bc0a09060500d6b1c9..4e0520fa33a57e2f15c39d362ec3a39945202d46 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -99,6 +99,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "layers_conv_variational_test", + size = "small", + srcs = ["python/kernel_tests/layers_conv_variational_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + ], +) + cuda_py_test( name = "layers_dense_variational_test", size = "small", diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index b1f108e5f01e4945ee83d8262f1d99877f0fe9f0..cbc66b6dc13db62c25952de6b6c13b2fdfe27f12 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Hamiltonian Monte Carlo. -""" +"""Tests for Hamiltonian Monte Carlo.""" from __future__ import absolute_import from __future__ import division @@ -27,6 +26,7 @@ from tensorflow.contrib.bayesflow.python.ops import hmc from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -46,6 +46,9 @@ class HMCTest(test.TestCase): random_seed.set_random_seed(10003) np.random.seed(10003) + def assertAllFinite(self, x): + self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) + def _log_gamma_log_prob(self, x, event_dims=()): """Computes log-pdf of a log-gamma random variable. @@ -345,5 +348,97 @@ class HMCTest(test.TestCase): def testAIS12(self): self._ais_gets_correct_log_normalizer_wrapper([1, 2]) + def testNanRejection(self): + """Tests that an update that yields NaN potentials gets rejected. + + We run HMC with a target distribution that returns NaN + log-likelihoods if any element of x < 0, and unit-scale + exponential log-likelihoods otherwise. The exponential potential + pushes x towards 0, ensuring that any reasonably large update will + push us over the edge into NaN territory. + """ + def _unbounded_exponential_log_prob(x): + """An exponential distribution with log-likelihood NaN for x < 0.""" + per_element_potentials = array_ops.where(x < 0, + np.nan * array_ops.ones_like(x), + -x) + return math_ops.reduce_sum(per_element_potentials) + + with self.test_session() as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, acceptance_probs, _, _ = hmc.kernel( + 2., 5, initial_x, _unbounded_exponential_log_prob, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) + + def testNanFromGradsDontPropagate(self): + """Test that update with NaN gradients does not cause NaN in results.""" + def _nan_log_prob_with_nan_gradient(x): + return np.nan * math_ops.reduce_sum(x) + + with self.test_session() as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel( + 2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) + + self.assertAllFinite( + gradients_impl.gradients(updated_x, initial_x)[0].eval()) + self.assertTrue( + gradients_impl.gradients(new_grad, initial_x)[0] is None) + + # Gradients of the acceptance probs and new log prob are not finite. + _ = new_log_prob # Prevent unused arg error. + # self.assertAllFinite( + # gradients_impl.gradients(acceptance_probs, initial_x)[0].eval()) + # self.assertAllFinite( + # gradients_impl.gradients(new_log_prob, initial_x)[0].eval()) + + def testChainWorksIn64Bit(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float64(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float64), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float64, states_.dtype) + self.assertEqual(np.float64, acceptance_probs_.dtype) + + def testChainWorksIn16Bit(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float16(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float16), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float16, states_.dtype) + self.assertEqual(np.float16, acceptance_probs_.dtype) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py new file mode 100644 index 0000000000000000000000000000000000000000..57f44aef1a198f62cd8a715472a68a3d889ec3ac --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py @@ -0,0 +1,289 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for convolutional Bayesian layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib +from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +class Counter(object): + """Helper class to manage incrementing a counting `int`.""" + + def __init__(self): + self._value = -1 + + @property + def value(self): + return self._value + + def __call__(self): + self._value += 1 + return self._value + + +class MockDistribution(independent_lib.Independent): + """Monitors DenseVariational calls to the underlying distribution.""" + + def __init__(self, result_sample, result_log_prob, loc=None, scale=None): + self.result_sample = result_sample + self.result_log_prob = result_log_prob + self.result_loc = loc + self.result_scale = scale + self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) + if loc is not None and scale is not None: + self.result_distribution = normal_lib.Normal(loc=self.result_loc, + scale=self.result_scale) + self.called_log_prob = Counter() + self.called_sample = Counter() + self.called_loc = Counter() + self.called_scale = Counter() + + def log_prob(self, *args, **kwargs): + self.called_log_prob() + return self.result_log_prob + + def sample(self, *args, **kwargs): + self.called_sample() + return self.result_sample + + @property + def distribution(self): # for dummy check on Independent(Normal) + return self.result_distribution + + @property + def loc(self): + self.called_loc() + return self.result_loc + + @property + def scale(self): + self.called_scale() + return self.result_scale + + +class MockKLDivergence(object): + """Monitors layer calls to the divergence implementation.""" + + def __init__(self, result): + self.result = result + self.args = [] + self.called = Counter() + + def __call__(self, *args, **kwargs): + self.called() + self.args.append(args) + return self.result + + +class ConvVariational(test.TestCase): + + def _testKLPenaltyKernel(self, layer_class): + with self.test_session(): + layer = layer_class(filters=2, kernel_size=3) + if layer_class == prob_layers_lib.Conv1DVariational: + inputs = random_ops.random_uniform([2, 3, 1], seed=1) + elif layer_class == prob_layers_lib.Conv2DVariational: + inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) + elif layer_class == prob_layers_lib.Conv3DVariational: + inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) + + # No keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) + + _ = layer(inputs) + + # Yes keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 1) + self.assertListEqual(layer.losses, losses) + + def _testKLPenaltyBoth(self, layer_class): + def _make_normal(dtype, *args): # pylint: disable=unused-argument + return normal_lib.Normal( + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) + with self.test_session(): + layer = layer_class( + filters=2, + kernel_size=3, + bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), + bias_prior_fn=_make_normal) + if layer_class == prob_layers_lib.Conv1DVariational: + inputs = random_ops.random_uniform([2, 3, 1], seed=1) + elif layer_class == prob_layers_lib.Conv2DVariational: + inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) + elif layer_class == prob_layers_lib.Conv3DVariational: + inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) + + # No keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) + + _ = layer(inputs) + + # Yes keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 2) + self.assertListEqual(layer.losses, losses) + + def _testConvVariational(self, layer_class): + batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 + with self.test_session() as sess: + seed = Counter() + if layer_class == prob_layers_lib.Conv1DVariational: + inputs = random_ops.random_uniform( + [batch_size, width, channels], seed=seed()) + kernel_size = (2,) + elif layer_class == prob_layers_lib.Conv2DVariational: + inputs = random_ops.random_uniform( + [batch_size, height, width, channels], seed=seed()) + kernel_size = (2, 2) + elif layer_class == prob_layers_lib.Conv3DVariational: + inputs = random_ops.random_uniform( + [batch_size, depth, height, width, channels], seed=seed()) + kernel_size = (2, 2, 2) + + kernel_shape = kernel_size + (channels, filters) + kernel_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), + result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), + result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_shape, seed=seed())) + + bias_size = (filters,) + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + convolution_op = nn_ops.Convolution( + tensor_shape.TensorShape(inputs.shape), + filter_shape=tensor_shape.TensorShape(kernel_shape), + padding="SAME") + expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) + expected_outputs = nn.bias_add(expected_outputs, + bias_posterior.result_sample, + data_format="NHWC") + + layer = layer_class( + filters=filters, + kernel_size=kernel_size, + padding="SAME", + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence) + + outputs = layer(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + + [ + expected_outputs_, actual_outputs_, + expected_kernel_, actual_kernel_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_posterior.result_sample, layer.kernel_posterior_tensor, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, layer.bias_posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_kernel_, actual_kernel_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior.distribution, + kernel_prior.distribution, + kernel_posterior.result_sample]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], + bias_divergence.args) + + def testKLPenaltyKernelConv1DVariational(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv1DVariational) + + def testKLPenaltyKernelConv2DVariational(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv2DVariational) + + def testKLPenaltyKernelConv3DVariational(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv3DVariational) + + def testKLPenaltyBothConv1DVariational(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv1DVariational) + + def testKLPenaltyBothConv2DVariational(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv2DVariational) + + def testKLPenaltyBothConv3DVariational(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv3DVariational) + + def testConv1DVariational(self): + self._testConvVariational(prob_layers_lib.Conv1DVariational) + + def testConv2DVariational(self): + self._testConvVariational(prob_layers_lib.Conv2DVariational) + + def testConv3DVariational(self): + self._testConvVariational(prob_layers_lib.Conv3DVariational) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py index 50358fd1c2b7635ffe2d08c5af3219bb0a11498b..4e9f1193511c35beead85914ca988fde69b3afde 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py @@ -18,11 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib +from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.platform import test @@ -41,14 +48,18 @@ class Counter(object): return self._value -class MockDistribution(normal_lib.Normal): - """Monitors DenseVariational calls to the underlying distribution.""" +class MockDistribution(independent_lib.Independent): + """Monitors layer calls to the underlying distribution.""" def __init__(self, result_sample, result_log_prob, loc=None, scale=None): self.result_sample = result_sample self.result_log_prob = result_log_prob self.result_loc = loc self.result_scale = scale + self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) + if loc is not None and scale is not None: + self.result_distribution = normal_lib.Normal(loc=self.result_loc, + scale=self.result_scale) self.called_log_prob = Counter() self.called_sample = Counter() self.called_loc = Counter() @@ -62,6 +73,10 @@ class MockDistribution(normal_lib.Normal): self.called_sample() return self.result_sample + @property + def distribution(self): # for dummy check on Independent(Normal) + return self.result_distribution + @property def loc(self): self.called_loc() @@ -74,7 +89,7 @@ class MockDistribution(normal_lib.Normal): class MockKLDivergence(object): - """Monitors DenseVariational calls to the divergence implementation.""" + """Monitors layer calls to the divergence implementation.""" def __init__(self, result): self.result = result @@ -87,94 +102,125 @@ class MockKLDivergence(object): return self.result -class DenseVariationalLocalReparametrization(test.TestCase): +class DenseVariational(test.TestCase): - def testKLPenaltyKernel(self): + def _testKLPenaltyKernel(self, layer_class): with self.test_session(): - dense_vi = prob_layers_lib.DenseVariational(units=2) + layer = layer_class(units=2) inputs = random_ops.random_uniform([2, 3], seed=1) # No keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 0) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) - _ = dense_vi(inputs) + _ = layer(inputs) # Yes keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 1) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 1) + self.assertListEqual(layer.losses, losses) - def testKLPenaltyBoth(self): + def _testKLPenaltyBoth(self, layer_class): def _make_normal(dtype, *args): # pylint: disable=unused-argument return normal_lib.Normal( loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) with self.test_session(): - dense_vi = prob_layers_lib.DenseVariational( + layer = layer_class( units=2, - bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(), + bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), bias_prior_fn=_make_normal) inputs = random_ops.random_uniform([2, 3], seed=1) # No keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 0) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) - _ = dense_vi(inputs) + _ = layer(inputs) # Yes keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 2) - self.assertListEqual(dense_vi.losses, loss_keys) - - def testVariationalNonLocal(self): + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 2) + self.assertListEqual(layer.losses, losses) + + def _testDenseSetUp(self, layer_class, batch_size, in_size, out_size, + **kwargs): + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform(kernel_size, seed=seed()), + scale=random_ops.random_uniform(kernel_size, seed=seed()), + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + layer = layer_class( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence, + **kwargs) + + outputs = layer(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + return (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, + layer, inputs, outputs, kl_penalty) + + def testKLPenaltyKernelReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseReparameterization) + + def testKLPenaltyKernelLocalReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseLocalReparameterization) + + def testKLPenaltyKernelFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseFlipout) + + def testKLPenaltyBothReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseReparameterization) + + def testKLPenaltyBothLocalReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseLocalReparameterization) + + def testKLPenaltyBothFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseFlipout) + + def testDenseReparameterization(self): batch_size, in_size, out_size = 2, 3, 4 with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseReparameterization, + batch_size, in_size, out_size) expected_outputs = ( math_ops.matmul(inputs, kernel_posterior.result_sample) + bias_posterior.result_sample) - dense_vi = prob_layers_lib.DenseVariational( - units=2, - kernel_use_local_reparameterization=False, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence) - - outputs = dense_vi(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - [ expected_outputs_, actual_outputs_, expected_kernel_, actual_kernel_, @@ -183,9 +229,9 @@ class DenseVariationalLocalReparametrization(test.TestCase): expected_bias_divergence_, actual_bias_divergence_, ] = sess.run([ expected_outputs, outputs, - kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor, + kernel_posterior.result_sample, layer.kernel_posterior_tensor, kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_posterior.result_sample, layer.bias_posterior_tensor, bias_divergence.result, kl_penalty[1], ]) @@ -206,40 +252,25 @@ class DenseVariationalLocalReparametrization(test.TestCase): rtol=1e-6, atol=0.) self.assertAllEqual( - [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]], + [[kernel_posterior.distribution, + kernel_prior.distribution, + kernel_posterior.result_sample]], kernel_divergence.args) self.assertAllEqual( - [[bias_posterior, bias_prior, bias_posterior.result_sample]], + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], bias_divergence.args) - def testVariationalLocal(self): + def testDenseLocalReparameterization(self): batch_size, in_size, out_size = 2, 3, 4 with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_size, seed=seed()), - scale=random_ops.random_uniform(kernel_size, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseLocalReparameterization, + batch_size, in_size, out_size) expected_kernel_posterior_affine = normal_lib.Normal( loc=math_ops.matmul(inputs, kernel_posterior.result_loc), @@ -250,21 +281,80 @@ class DenseVariationalLocalReparametrization(test.TestCase): expected_outputs = (expected_kernel_posterior_affine_tensor + bias_posterior.result_sample) - dense_vi = prob_layers_lib.DenseVariational( - units=2, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence) + [ + expected_outputs_, actual_outputs_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, layer.bias_posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) - outputs = dense_vi(inputs) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior.distribution, + kernel_prior.distribution, + None]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], + bias_divergence.args) + + def testDenseFlipout(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseFlipout, + batch_size, in_size, out_size, seed=44) + + expected_kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(kernel_posterior.result_loc), + scale=kernel_posterior.result_scale) + expected_kernel_posterior_affine_tensor = ( + expected_kernel_posterior_affine.sample(seed=42)) - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + sign_input = random_ops.random_uniform( + [batch_size, in_size], + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=layer.seed) + sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) + sign_output = random_ops.random_uniform( + [batch_size, out_size], + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=distribution_util.gen_new_seed( + layer.seed, salt="dense_flipout")) + sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) + perturbed_inputs = math_ops.matmul( + inputs * sign_input, expected_kernel_posterior_affine_tensor) + perturbed_inputs *= sign_output + + expected_outputs = math_ops.matmul(inputs, kernel_posterior.result_loc) + expected_outputs += perturbed_inputs + expected_outputs += bias_posterior.result_sample [ expected_outputs_, actual_outputs_, @@ -274,7 +364,7 @@ class DenseVariationalLocalReparametrization(test.TestCase): ] = sess.run([ expected_outputs, outputs, kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_posterior.result_sample, layer.bias_posterior_tensor, bias_divergence.result, kl_penalty[1], ]) @@ -292,13 +382,62 @@ class DenseVariationalLocalReparametrization(test.TestCase): rtol=1e-6, atol=0.) self.assertAllEqual( - [[kernel_posterior, kernel_prior, None]], + [[kernel_posterior.distribution, kernel_prior.distribution, None]], kernel_divergence.args) self.assertAllEqual( - [[bias_posterior, bias_prior, bias_posterior.result_sample]], + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], bias_divergence.args) + def testRandomDenseFlipout(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + scale=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + result_log_prob=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + result_sample=random_ops.random_uniform( + [in_size, out_size], seed=seed())) + bias_posterior = MockDistribution( + loc=random_ops.random_uniform( + [out_size], seed=seed()), + scale=random_ops.random_uniform( + [out_size], seed=seed()), + result_log_prob=random_ops.random_uniform( + [out_size], seed=seed()), + result_sample=random_ops.random_uniform( + [out_size], seed=seed())) + layer_one = prob_layers_lib.DenseFlipout( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=44) + layer_two = prob_layers_lib.DenseFlipout( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=45) + + outputs_one = layer_one(inputs) + outputs_two = layer_two(inputs) + + outputs_one_, outputs_two_ = sess.run([ + outputs_one, outputs_two]) + + self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), out_size) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index 333dce929530adceb30dcb63653a5bd009c059e0..5685a942e98800a39ec718adc67bcfd43aeafd52 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -27,6 +27,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -174,9 +175,11 @@ def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, potential_and_grad = _make_potential_and_grad(target_log_prob_fn) potential, grad = potential_and_grad(initial_x) - return functional_ops.scan(body, array_ops.zeros(n_iterations), - (initial_x, array_ops.zeros(non_event_shape), - -potential, -grad))[:2] + return functional_ops.scan( + body, array_ops.zeros(n_iterations, dtype=initial_x.dtype), + (initial_x, + array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + -potential, -grad))[:2] def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, @@ -298,8 +301,9 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, return updated_x, acceptance_probs, w x, acceptance_probs, w = functional_ops.scan( - _body, beta_series, (initial_x, array_ops.zeros(non_event_shape), - array_ops.zeros(non_event_shape))) + _body, beta_series, + (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + array_ops.zeros(non_event_shape, dtype=initial_x.dtype))) return w[-1], x[-1], acceptance_probs[-1] @@ -446,9 +450,10 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), """ with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): potential_and_grad = _make_potential_and_grad(target_log_prob_fn) + x = ops.convert_to_tensor(x, name='x') x_shape = array_ops.shape(x) - m = random_ops.random_normal(x_shape) + m = random_ops.random_normal(x_shape, dtype=x.dtype) kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) @@ -468,26 +473,33 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims) - # TODO(mhoffman): It seems like there may be an opportunity for nans here. - # I'm delaying addressing this because we're going to refactor this part - # to use the more general Metropolis abstraction anyway. - acceptance_probs = math_ops.exp(math_ops.minimum(0., log_potential_0 - - log_potential_1 + - kinetic_0 - kinetic_1)) - accepted = math_ops.cast( - random_ops.random_uniform(array_ops.shape(acceptance_probs)) < - acceptance_probs, np.float32) - new_log_prob = (-log_potential_0 * (1. - accepted) - - log_potential_1 * accepted) + energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0 + # Treat NaN as infinite energy (and therefore guaranteed rejection). + energy_change = array_ops.where( + math_ops.is_nan(energy_change), + array_ops.fill(array_ops.shape(energy_change), + energy_change.dtype.as_numpy_dtype(np.inf)), + energy_change) + acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.)) + accepted = ( + random_ops.random_uniform( + array_ops.shape(acceptance_probs), dtype=x.dtype) + < acceptance_probs) + new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0) # TODO(b/65738010): This should work, but it doesn't for now. # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, keep_dims=True)) accepted = array_ops.reshape(accepted, reduced_shape) - new_x = x * (1. - accepted) + new_x * accepted - new_grad = -grad_0 * (1. - accepted) - grad_1 * accepted - + accepted = math_ops.logical_or( + accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool)) + new_x = array_ops.where(accepted, new_x, x) + new_grad = -array_ops.where(accepted, grad_1, grad_0) + + # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect + # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This + # should be fixed. return new_x, acceptance_probs, new_log_prob, new_grad @@ -525,6 +537,7 @@ def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum, Has shape matching `initial_position`. Example: Simple quadratic potential. + ```python def potential_and_grad(position): return tf.reduce_sum(0.5 * tf.square(position)), position @@ -600,6 +613,7 @@ def leapfrog_step(step_size, position, momentum, potential_and_grad, grad, Has shape matching `position`. Example: Simple quadratic potential. + ```python def potential_and_grad(position): # Simple quadratic potential diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py index dcead38af826a12e776160bdb251ba021e6b953c..93412afae738564d440065f230c9df0036591467 100644 --- a/tensorflow/contrib/bayesflow/python/ops/layers.py +++ b/tensorflow/contrib/bayesflow/python/ops/layers.py @@ -23,13 +23,31 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.layers_conv_variational import * from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import * +from tensorflow.contrib.bayesflow.python.ops.layers_util import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'DenseVariational', - 'dense_variational', + 'Convolution1DVariational', + 'Convolution2DVariational', + 'Convolution3DVariational', + 'Conv1DVariational', + 'Conv2DVariational', + 'Conv3DVariational', + 'convolution1d_variational', + 'convolution2d_variational', + 'convolution3d_variational', + 'conv1d_variational', + 'conv2d_variational', + 'conv3d_variational', + 'DenseReparameterization', + 'DenseLocalReparameterization', + 'DenseFlipout', + 'dense_reparameterization', + 'dense_local_reparameterization', + 'dense_flipout', 'default_loc_scale_fn', 'default_mean_field_normal_fn', ] diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffb55feb1ad754bf96473c075ad6fd38d4e8be9 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py @@ -0,0 +1,1415 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Convolutional variational layer classes and their functional aliases. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib + + +class _ConvVariational(layers_lib.Layer): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + + Properties: + rank: Python integer, dimensionality of convolution. + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(_ConvVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.rank = rank + self.filters = filters + self.kernel_size = utils.normalize_tuple(kernel_size, rank, "kernel_size") + self.strides = utils.normalize_tuple(strides, rank, "strides") + self.padding = utils.normalize_padding(padding) + self.data_format = utils.normalize_data_format(data_format) + self.dilation_rate = utils.normalize_tuple( + dilation_rate, rank, "dilation_rate") + self.activation = activation + self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2) + self.kernel_posterior_fn = kernel_posterior_fn + self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn + self.kernel_prior_fn = kernel_prior_fn + self.kernel_divergence_fn = kernel_divergence_fn + self.bias_posterior_fn = bias_posterior_fn + self.bias_posterior_tensor_fn = bias_posterior_tensor_fn + self.bias_prior_fn = bias_prior_fn + self.bias_divergence_fn = bias_divergence_fn + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if self.data_format == "channels_first": + channel_axis = 1 + else: + channel_axis = -1 + if input_shape[channel_axis].value is None: + raise ValueError("The channel dimension of the inputs " + "should be defined. Found `None`.") + input_dim = input_shape[channel_axis].value + kernel_shape = self.kernel_size + (input_dim, self.filters) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. + self.kernel_posterior = self.kernel_posterior_fn( + dtype, kernel_shape, "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel_prior_fn is None: + self.kernel_prior = None + else: + self.kernel_prior = self.kernel_prior_fn( + dtype, kernel_shape, "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias_posterior_fn is None: + self.bias_posterior = None + else: + self.bias_posterior = self.bias_posterior_fn( + dtype, (self.filters,), "bias_posterior", + self.trainable, self.add_variable) + + if self.bias_prior_fn is None: + self.bias_prior = None + else: + self.bias_prior = self.bias_prior_fn( + dtype, (self.filters,), "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2, + axes={channel_axis: input_dim}) + self._convolution_op = nn_ops.Convolution( + input_shape, + filter_shape=tensor_shape.TensorShape(kernel_shape), + dilation_rate=self.dilation_rate, + strides=self.strides, + padding=self.padding.upper(), + data_format=utils.convert_data_format(self.data_format, + self.rank + 2)) + + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) + if not self._built_kernel_divergence: + kernel_posterior = self.kernel_posterior + kernel_prior = self.kernel_prior + if isinstance(self.kernel_posterior, independent_lib.Independent): + kernel_posterior = kernel_posterior.distribution + if isinstance(self.kernel_prior, independent_lib.Independent): + kernel_prior = kernel_prior.distribution + self._apply_divergence(self.kernel_divergence_fn, + kernel_posterior, + kernel_prior, + self.kernel_posterior_tensor, + name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + bias_posterior = self.bias_posterior + bias_prior = self.bias_prior + if isinstance(self.bias_posterior, independent_lib.Independent): + bias_posterior = bias_posterior.distribution + if isinstance(self.bias_prior, independent_lib.Independent): + bias_prior = bias_prior.distribution + self._apply_divergence(self.bias_divergence_fn, + bias_posterior, + bias_prior, + self.bias_posterior_tensor, + name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_kernel(self, inputs): + self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( + self.kernel_posterior) + outputs = self._convolution_op(inputs, self.kernel_posterior_tensor) + return outputs + + def _apply_variational_bias(self, inputs): + if self.bias_posterior is None: + self.bias_posterior_tensor = None + return inputs + self.bias_posterior_tensor = self.bias_posterior_tensor_fn( + self.bias_posterior) + outputs = inputs + if self.data_format == "channels_first": + if self.rank == 1: + # nn.bias_add does not accept a 1D input tensor. + bias = array_ops.reshape(self.bias_posterior_tensor, + (1, self.filters, 1)) + outputs += bias + if self.rank == 2: + outputs = nn.bias_add(outputs, + self.bias_posterior_tensor, + data_format="NCHW") + if self.rank == 3: + # As of Mar 2017, direct addition is significantly slower than + # bias_add when computing gradients. To use bias_add, we collapse Z + # and Y into a single dimension to obtain a 4D input tensor. + outputs_shape = outputs.shape.as_list() + outputs_4d = array_ops.reshape(outputs, + [outputs_shape[0], outputs_shape[1], + outputs_shape[2] * outputs_shape[3], + outputs_shape[4]]) + outputs_4d = nn.bias_add(outputs_4d, + self.bias_posterior_tensor, + data_format="NCHW") + outputs = array_ops.reshape(outputs_4d, outputs_shape) + else: + outputs = nn.bias_add(outputs, + self.bias_posterior_tensor, + data_format="NHWC") + return outputs + + def _apply_divergence(self, divergence_fn, posterior, prior, + posterior_tensor, name): + if (divergence_fn is None or + posterior is None or + prior is None): + divergence = None + return + divergence = standard_ops.identity( + divergence_fn( + posterior, prior, posterior_tensor), + name=name) + self.add_loss(divergence) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).as_list() + if self.data_format == "channels_last": + space = input_shape[1:-1] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0]] + new_space + + [self.filters]) + else: + space = input_shape[2:] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0], self.filters] + + new_space) + + +class Conv1DVariational(_ConvVariational): + """1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.Conv1DVariational(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.DenseVariational(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv1DVariational, self).__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv1d_variational( + inputs, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for 1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.conv1d_variational(net, + 64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.dense_variational(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + layer = Conv1DVariational( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv2DVariational(_ConvVariational): + """2D convolution layer (e.g. spatial convolution over images). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.Conv2DVariational(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling2D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 8 * 8 * 64]) + logits = tfp.layers.DenseVariational(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv2DVariational, self).__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv2d_variational( + inputs, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for the 2D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.conv2d_variational(net, + 64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling2d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 8 * 8 * 64]) + logits = tfp.layers.dense_variational(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + layer = Conv2DVariational( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv3DVariational(_ConvVariational): + """3D convolution layer (e.g. spatial convolution over volumes). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.Conv3DVariational(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling2D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) + logits = tfp.layers.DenseVariational(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv3DVariational, self).__init__( + rank=3, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv3d_variational( + inputs, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for the 3D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.conv3d_variational(net, + 64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling2d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) + logits = tfp.layers.dense_variational(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + layer = Conv3DVariational( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +# Aliases + +Convolution1DVariational = Conv1DVariational +Convolution2DVariational = Conv2DVariational +Convolution3DVariational = Conv3DVariational +convolution1d_variational = conv1d_variational +convolution2d_variational = conv2d_variational +convolution3d_variational = conv3d_variational diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py index b05ce0ffc1dd55ffb029b339a846a9aa5c877620..a749a396f15188ef345b4ae7c53017b6004c5e71 100644 --- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py @@ -14,221 +14,60 @@ # ============================================================================== """Dense Bayesian layer using KL-divergence based variational inference. -@@DenseVariational -@@dense_variational - -@@default_loc_scale_fn -@@default_mean_field_normal_fn +@@DenseReparameterization +@@DenseLocalReparameterization +@@DenseFlipout +@@dense_reparameterization +@@dense_local_reparameterization +@@dense_flipout """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.contrib.bayesflow.python.ops import layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as layers_lib -from tensorflow.python.ops import init_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import standard_ops from tensorflow.python.ops.distributions import kullback_leibler as kl_lib from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ - "DenseVariational", - "dense_variational", - "default_loc_scale_fn", - "default_mean_field_normal_fn", + "DenseReparameterization", + "DenseLocalReparameterization", + "DenseFlipout", + "dense_reparameterization", + "dense_local_reparameterization", + "dense_flipout", ] -def default_loc_scale_fn( - is_singular=False, - loc_initializer=init_ops.random_normal_initializer(stddev=0.1), - untransformed_scale_initializer=init_ops.random_normal_initializer( - mean=-3., stddev=0.1), - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. - - This function produces a closure which produces `loc`, `scale` using - `tf.get_variable`. The closure accepts the following arguments: - - dtype: Type of parameter's event. - shape: Python `list`-like representing the parameter's event shape. - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` indicating if `scale is None`. Default: `False`. - loc_initializer: Initializer function for the `loc` parameters. - The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. Default value: `tf.random_normal_initializer(mean=-3., - stddev=0.1)`. This implies the softplus transformed result has mean - approximately `0.05` and std. deviation approximately `0.005`. - loc_regularizer: Regularizer function for the `loc` parameters. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. The default (`None`) is to use the `tf.get_variable` default. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. The default - (`None`) is to use the `tf.get_variable` default. - - Returns: - default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` - parameters from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates `loc`, `scale` parameters.""" - loc = add_variable_fn( - name=name + "_loc", - shape=shape, - initializer=loc_initializer, - regularizer=loc_regularizer, - constraint=loc_constraint, - dtype=dtype, - trainable=trainable) - if is_singular: - return loc, None - untransformed_scale = add_variable_fn( - name=name + "_untransformed_scale", - shape=shape, - initializer=untransformed_scale_initializer, - regularizer=untransformed_scale_regularizer, - constraint=untransformed_scale_constraint, - dtype=dtype, - trainable=trainable) - scale = (np.finfo(dtype.as_numpy_dtype).eps + - nn_ops.softplus(untransformed_scale)) - return loc, scale - return _fn - - -def default_mean_field_normal_fn( - is_singular=False, - loc_initializer=None, - untransformed_scale_initializer=None, - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Creates a function to build Normal distributions with trainable params. - - This function produces a closure which produces `tf.distributions.Normal` - parameterized by a loc` and `scale` each created using `tf.get_variable`. The - produced closure accepts the following arguments: - - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. +class _DenseVariational(layers_lib.Layer): + """Abstract densely-connected class (private, used as implementation base). - Args: - is_singular: Python `bool` if `True`, forces the special case limit of - `scale->0`, i.e., a `Deterministic` distribution. - loc_initializer: Initializer function for the `loc` parameters. - If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - loc_regularizer: Regularizer function for the `loc` parameters. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. - - Returns: - make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` - using from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - loc_scale_fn_ = default_loc_scale_fn( - is_singular, - loc_initializer, - untransformed_scale_initializer, - loc_regularizer, - untransformed_scale_regularizer, - loc_constraint, - untransformed_scale_constraint) - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates a batch of `Deterministic` or `Normal` distributions.""" - loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) - if scale is None: - return deterministic_lib.Deterministic(loc=loc) - return normal_lib.Normal(loc=loc, scale=scale) - return _fn - - -class DenseVariational(layers_lib.Layer): - """Densely-connected variational class. - - This layer implements the Bayesian variational inference analogue to: - `outputs = activation(matmul(inputs, kernel) + bias)` - by assuming the `kernel` and/or the `bias` are random variables. - - The layer implements a stochastic dense calculation by making a Monte Carlo - approximation of a [variational Bayesian method based on KL divergence]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, ```none - -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw - = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw - <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's - = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] - ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } - + KL[q(W|x), p(W)] + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) ``` - where `W` denotes the (independent) `kernel` and `bias` random variables, `w` - is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, - and `~=` denotes an approximation which becomes exact as `m->inf`. The above - bound is sometimes referred to as the negative Evidence Lower BOund or - negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this - layer is appropriate to use when the final loss is a negative log-likelihood. - - The Monte-Carlo sum portion is used for the feed-forward calculation of the - DNN. The KL divergence portion can be added to the final loss via: - `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. - The arguments permit separate specification of the surrogate posterior (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - random variables (which together comprise `W`). + distributions. Args: units: Integer or Long, dimensionality of the output space. @@ -237,10 +76,6 @@ class DenseVariational(layers_lib.Layer): activity_regularizer: Regularizer function for the output. trainable: Boolean, if `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. - When `True`, `kernel_posterior_fn` must create an instance of - `tf.distributions.Normal`. kernel_posterior_fn: Python `callable` which creates `tf.distributions.Distribution` instance representing the surrogate posterior of the `kernel` parameter. Default value: @@ -283,12 +118,14 @@ class DenseVariational(layers_lib.Layer): units: Python integer, dimensionality of the output space. activation: Activation function (`callable`). activity_regularizer: Regularizer function for the output. - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. - kernel: `VariationalKernelParamater` instance containing all `kernel` - related properties and `callable`s. - bias: `VariationalParameter` instance containing all `kernel` - related properties and `callable`s. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. """ def __init__( @@ -297,66 +134,33 @@ class DenseVariational(layers_lib.Layer): activation=None, activity_regularizer=None, trainable=True, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), kernel_posterior_tensor_fn=lambda d: d.sample(), kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long bias_posterior_tensor_fn=lambda d: d.sample(), bias_prior_fn=None, bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): - super(DenseVariational, self).__init__( + super(_DenseVariational, self).__init__( trainable=trainable, name=name, activity_regularizer=activity_regularizer, **kwargs) - self._units = units - self._activation = activation - self._input_spec = layers_lib.InputSpec(min_ndim=2) - self._kernel_use_local_reparameterization = ( - kernel_use_local_reparameterization) - self._kernel = VariationalKernelParameter( - kernel_posterior_fn, - kernel_posterior_tensor_fn, - kernel_prior_fn, - kernel_divergence_fn) - self._bias = VariationalParameter( - bias_posterior_fn, - bias_posterior_tensor_fn, - bias_prior_fn, - bias_divergence_fn) - - @property - def units(self): - return self._units - - @property - def activation(self): - return self._activation - - @property - def input_spec(self): - return self._input_spec - - @input_spec.setter - def input_spec(self, value): - self._input_spec = value - - @property - def kernel_use_local_reparameterization(self): - return self._kernel_use_local_reparameterization - - @property - def kernel(self): - return self._kernel - - @property - def bias(self): - return self._bias + self.units = units + self.activation = activation + self.input_spec = layers_lib.InputSpec(min_ndim=2) + self.kernel_posterior_fn = kernel_posterior_fn + self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn + self.kernel_prior_fn = kernel_prior_fn + self.kernel_divergence_fn = kernel_divergence_fn + self.bias_posterior_fn = bias_posterior_fn + self.bias_posterior_tensor_fn = bias_posterior_tensor_fn + self.bias_prior_fn = bias_prior_fn + self.bias_divergence_fn = bias_divergence_fn def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) @@ -368,29 +172,29 @@ class DenseVariational(layers_lib.Layer): dtype = dtypes.as_dtype(self.dtype) # Must have a posterior kernel. - self.kernel.posterior = self.kernel.posterior_fn( + self.kernel_posterior = self.kernel_posterior_fn( dtype, [in_size, self.units], "kernel_posterior", self.trainable, self.add_variable) - if self.kernel.prior_fn is None: + if self.kernel_prior_fn is None: self.kernel_prior = None else: - self.kernel.prior = self.kernel.prior_fn( + self.kernel_prior = self.kernel_prior_fn( dtype, [in_size, self.units], "kernel_prior", self.trainable, self.add_variable) self._built_kernel_divergence = False - if self.bias.posterior_fn is None: - self.bias.posterior = None + if self.bias_posterior_fn is None: + self.bias_posterior = None else: - self.bias.posterior = self.bias.posterior_fn( + self.bias_posterior = self.bias_posterior_fn( dtype, [self.units], "bias_posterior", self.trainable, self.add_variable) - if self.bias.prior_fn is None: - self.bias.prior = None + if self.bias_prior_fn is None: + self.bias_prior = None else: - self.bias.prior = self.bias.prior_fn( + self.bias_prior = self.bias_prior_fn( dtype, [self.units], "bias_prior", self.trainable, self.add_variable) self._built_bias_divergence = False @@ -405,54 +209,53 @@ class DenseVariational(layers_lib.Layer): if self.activation is not None: outputs = self.activation(outputs) # pylint: disable=not-callable if not self._built_kernel_divergence: - self._apply_divergence(self.kernel, name="divergence_kernel") + kernel_posterior = self.kernel_posterior + kernel_prior = self.kernel_prior + if isinstance(self.kernel_posterior, independent_lib.Independent): + kernel_posterior = kernel_posterior.distribution + if isinstance(self.kernel_prior, independent_lib.Independent): + kernel_prior = kernel_prior.distribution + self._apply_divergence(self.kernel_divergence_fn, + kernel_posterior, + kernel_prior, + self.kernel_posterior_tensor, + name="divergence_kernel") self._built_kernel_divergence = True if not self._built_bias_divergence: - self._apply_divergence(self.bias, name="divergence_bias") + bias_posterior = self.bias_posterior + bias_prior = self.bias_prior + if isinstance(self.bias_posterior, independent_lib.Independent): + bias_posterior = bias_posterior.distribution + if isinstance(self.bias_prior, independent_lib.Independent): + bias_prior = bias_prior.distribution + self._apply_divergence(self.bias_divergence_fn, + bias_posterior, + bias_prior, + self.bias_posterior_tensor, + name="divergence_bias") self._built_bias_divergence = True return outputs - def _apply_variational_kernel(self, inputs): - if not self.kernel_use_local_reparameterization: - self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn( - self.kernel.posterior) - self.kernel.posterior_affine = None - self.kernel.posterior_affine_tensor = None - return self._matmul(inputs, self.kernel.posterior_tensor) - if not isinstance(self.kernel.posterior, normal_lib.Normal): - raise TypeError("`kernel_use_local_reparameterization=True` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Normal` (saw: \"{}\").".format( - type(self.kernel.posterior).__name__)) - self.kernel.posterior_affine = normal_lib.Normal( - loc=self._matmul(inputs, self.kernel.posterior.loc), - scale=standard_ops.sqrt(self._matmul( - standard_ops.square(inputs), - standard_ops.square(self.kernel.posterior.scale)))) - self.kernel.posterior_affine_tensor = ( - self.kernel.posterior_tensor_fn(self.kernel.posterior_affine)) - self.kernel.posterior_tensor = None - return self.kernel.posterior_affine_tensor - def _apply_variational_bias(self, inputs): - if self.bias.posterior is None: - self.bias.posterior_tensor = None + if self.bias_posterior is None: + self.bias_posterior_tensor = None return inputs - self.bias.posterior_tensor = self.bias.posterior_tensor_fn( - self.bias.posterior) - return nn.bias_add(inputs, self.bias.posterior_tensor) - - def _apply_divergence(self, param, name): - if (param.divergence_fn is None or - param.posterior is None or - param.prior is None): - param.divergence = None + self.bias_posterior_tensor = self.bias_posterior_tensor_fn( + self.bias_posterior) + return nn.bias_add(inputs, self.bias_posterior_tensor) + + def _apply_divergence(self, divergence_fn, posterior, prior, + posterior_tensor, name): + if (divergence_fn is None or + posterior is None or + prior is None): + divergence = None return - param.divergence = standard_ops.identity( - param.divergence_fn( - param.posterior, param.prior, param.posterior_tensor), + divergence = standard_ops.identity( + divergence_fn( + posterior, prior, posterior_tensor), name=name) - self.add_loss(param.divergence) + self.add_loss(divergence) def _matmul(self, inputs, kernel): if inputs.shape.ndims <= 2: @@ -469,57 +272,467 @@ class DenseVariational(layers_lib.Layer): return input_shape[:-1].concatenate(self.units) -def dense_variational( +class DenseReparameterization(_DenseVariational): + """Densely-connected layer class with reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseReparameterization( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseReparameterization, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + + def _apply_variational_kernel(self, inputs): + self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( + self.kernel_posterior) + self.kernel_posterior_affine = None + self.kernel_posterior_affine_tensor = None + return self._matmul(inputs, self.kernel_posterior_tensor) + + +def dense_reparameterization( inputs, units, activation=None, activity_regularizer=None, trainable=True, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), kernel_posterior_tensor_fn=lambda d: d.sample(), kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long bias_posterior_tensor_fn=lambda d: d.sample(), bias_prior_fn=None, bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, reuse=None): - """Densely-connected variational layer. + """Densely-connected layer with reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + + #### Examples - This layer implements the Bayesian variational inference analogue to: - `outputs = activation(matmul(inputs, kernel) + bias)` - by assuming the `kernel` and/or the `bias` are random variables. + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. - The layer implements a stochastic dense calculation by making a Monte Carlo - approximation of a [variational Bayesian method based on KL divergence]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_reparameterization( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + layer = DenseReparameterization( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class DenseLocalReparameterization(_DenseVariational): + """Densely-connected layer class with local reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, ```none - -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw - = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw - <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's - = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] - ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } - + KL[q(W|x), p(W)] + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseLocalReparameterization( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseLocalReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) ``` - where `W` denotes the (independent) `kernel` and `bias` random variables, `w` - is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, - and `~=` denotes an approximation which becomes exact as `m->inf`. The above - bound is sometimes referred to as the negative Evidence Lower BOund or - negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this - layer is appropriate to use when the final loss is a negative log-likelihood. + It uses local reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseLocalReparameterization, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + + def _apply_variational_kernel(self, inputs): + if (not isinstance(self.kernel_posterior, independent_lib.Independent) or + not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): + raise TypeError( + "`DenseLocalReparameterization` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Independent(tf.distributions.Normal)` " + "(saw: \"{}\").".format(type(self.kernel_posterior).__name__)) + self.kernel_posterior_affine = normal_lib.Normal( + loc=self._matmul(inputs, self.kernel_posterior.distribution.loc), + scale=standard_ops.sqrt(self._matmul( + standard_ops.square(inputs), + standard_ops.square(self.kernel_posterior.distribution.scale)))) + self.kernel_posterior_affine_tensor = ( + self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) + self.kernel_posterior_tensor = None + return self.kernel_posterior_affine_tensor - The Monte-Carlo sum portion is used for the feed-forward calculation of the - DNN. The KL divergence portion can be added to the final loss via: - `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + +def dense_local_reparameterization( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected layer with local reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` The arguments permit separate specification of the surrogate posterior (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - random variables (which together comprise `W`). + distributions. Args: inputs: Tensor input. @@ -529,10 +742,6 @@ def dense_variational( activity_regularizer: Regularizer function for the output. trainable: Boolean, if `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. - When `True`, `kernel_posterior_fn` must create an instance of - `tf.distributions.Normal`. kernel_posterior_fn: Python `callable` which creates `tf.distributions.Distribution` instance representing the surrogate posterior of the `kernel` parameter. Default value: @@ -574,14 +783,38 @@ def dense_variational( Returns: output: `Tensor` representing a the affine transformed input under a random draw from the surrogate posterior distribution. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_local_reparameterization( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_local_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses local reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. """ - layer = DenseVariational( + layer = DenseLocalReparameterization( units, activation=activation, activity_regularizer=activity_regularizer, trainable=trainable, - kernel_use_local_reparameterization=( - kernel_use_local_reparameterization), kernel_posterior_fn=kernel_posterior_fn, kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, kernel_prior_fn=kernel_prior_fn, @@ -597,201 +830,317 @@ def dense_variational( return layer.apply(inputs) -class NotSet(object): - """Helper to track whether a `VariationalParameter` value has been set.""" - pass +class DenseFlipout(_DenseVariational): + """Densely-connected layer class with Flipout estimator. + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, -class VariationalParameter(object): - """Struct-like container of variational parameter properties. + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. - A `VariationalParameter` is intitialized with Python `callable`s which set the - value of correspondingly named members. Corresponding values have "set once" - semantics, i.e., once set to any value they are immutable. + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseFlipout( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. """ def __init__( self, - posterior_fn, - posterior_tensor_fn, - prior_fn, - divergence_fn): - """Creates the `VariationalParameter` struct-like object. - - Args: - posterior_fn: Python `callable` which creates a - `tf.distribution.Distribution` like object representing the posterior - distribution. See `VariationalParameter.posterior_fn` for `callable`'s - required parameters. - posterior_tensor_fn: Python `callable` which computes a `Tensor` - which represents the `posterior`. - prior_fn: Python `callable` which creates a - `tf.distribution.Distribution` like object representing the prior - distribution. See `VariationalParameter.prior_fn` for `callable`'s - required parameters. - divergence_fn: Python `callable` which computes the KL divergence from - `posterior` to `prior`. See `VariationalParameter.divergence_fn` for - required `callable`'s parameters. - """ - self._posterior_fn = posterior_fn - self._posterior = NotSet() - self._posterior_tensor_fn = posterior_tensor_fn - self._posterior_tensor = NotSet() - self._prior_fn = prior_fn - self._prior = NotSet() - self._divergence_fn = divergence_fn - self._divergence = NotSet() - self._init_helper() - - @property - def posterior_fn(self): - """`callable` which creates `tf.distributions.Distribution`-like posterior. - - The `callable` must accept the following parameters: - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Returns: - posterior_fn: The Python `callable` specified in `__init__`. - """ - return self._posterior_fn - - @property - def posterior(self): - """`tf.distributions.Distribution`-like instance representing posterior.""" - return self._posterior - - @posterior.setter - def posterior(self, value): - """One-time setter of the `posterior` distribution.""" - if not isinstance(self._posterior, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior = value - - @property - def posterior_tensor_fn(self): - """Creates `Tensor` representing the `posterior` distribution. - - The `callable` must accept the following parameters: - posterior: `tf.distributions.Distribution`-like instance. - - Returns: - posterior_tensor_fn: The Python `callable` specified in - `__init__`. - """ - return self._posterior_tensor_fn - - @property - def posterior_tensor(self): - """`Tensor` representing the `posterior` distribution.""" - return self._posterior_tensor - - @posterior_tensor.setter - def posterior_tensor(self, value): - """One-time setter of the `posterior_tensor`.""" - if not isinstance(self._posterior_tensor, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_tensor = value - - @property - def prior_fn(self): - """`callable` which creates `tf.distributions.Distribution`-like prior. - - The `callable` must accept the following parameters: - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Returns: - prior_fn: The Python `callable` specified in `__init__`. - """ - return self._prior_fn - - @property - def prior(self): - """`tf.distributions.Distribution`-like instance representing posterior.""" - return self._prior - - @prior.setter - def prior(self, value): - """One-time setter of the `prior` distribution.""" - if not isinstance(self._prior, NotSet): - raise ValueError("Cannot override already set attribute.") - self._prior = value - - @property - def divergence_fn(self): - """`callable` which computes KL-divergence `Tensor` from posterior to prior. - - The `callable` must accept the following parameters: - posterior: `tf.distributions.Distribution`-like instance. - prior: `tf.distributions.Distribution`-like instance. - posterior_tensor: `Tensor` representing value of posterior. - - Returns: - divergence_fn: The Python `callable` specified in `__init__`. - """ - return self._divergence_fn - - @property - def divergence(self): - """`Tensor` representing KL-divergence from posterior to prior.""" - return self._divergence - - @divergence.setter - def divergence(self, value): - """One-time setter of the `divergence`.""" - if not isinstance(self._divergence, NotSet): - raise ValueError("Cannot override already set attribute.") - self._divergence = value - - def _init_helper(self): - pass - - -class VariationalKernelParameter(VariationalParameter): - """Struct-like container of variational kernel properties. - - A `VariationalKernelParameter` is intitialized with Python `callable`s which - set the value of correspondingly named members. Corresponding values have "set - once" semantics, i.e., once set to any value they are immutable. + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(DenseFlipout, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + self.seed = seed + + def _apply_variational_kernel(self, inputs): + if (not isinstance(self.kernel_posterior, independent_lib.Independent) or + not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): + raise TypeError( + "`DenseFlipout` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Independent(tf.distributions.Normal)` " + "(saw: \"{}\").".format(type(self.kernel_posterior).__name__)) + self.kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), + scale=self.kernel_posterior.distribution.scale) + self.kernel_posterior_affine_tensor = ( + self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) + self.kernel_posterior_tensor = None + + input_shape = array_ops.shape(inputs) + batch_shape = input_shape[:-1] + + sign_input = random_sign(input_shape, dtype=inputs.dtype, seed=self.seed) + sign_output = random_sign( + array_ops.concat([batch_shape, + array_ops.expand_dims(self.units, 0)], 0), + dtype=inputs.dtype, + seed=distribution_util.gen_new_seed( + self.seed, salt="dense_flipout")) + perturbed_inputs = self._matmul( + inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output + + outputs = self._matmul(inputs, self.kernel_posterior.distribution.loc) + outputs += perturbed_inputs + return outputs + + +def dense_flipout( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Densely-connected layer with Flipout estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_flipout( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. """ + layer = DenseFlipout( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + - @property - def posterior_affine(self): - """`tf.distributions.Distribution` affine transformed posterior.""" - return self._posterior_affine - - @posterior_affine.setter - def posterior_affine(self, value): - """One-time setter of `posterior_affine`.""" - if not isinstance(self._posterior_affine, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_affine = value - - @property - def posterior_affine_tensor(self): - """`Tensor` representing the `posterior_affine` distribution.""" - return self._posterior_affine_tensor - - @posterior_affine_tensor.setter - def posterior_affine_tensor(self, value): - """One-time setter of the `posterior_affine_tensor`.""" - if not isinstance(self._posterior_affine_tensor, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_affine_tensor = value - - def _init_helper(self): - self._posterior_affine = NotSet() - self._posterior_affine_tensor = NotSet() +def random_sign(shape, dtype=dtypes.float32, seed=None): + """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution.""" + random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2, + dtype=dtypes.int32, + seed=seed) + return math_ops.cast(2 * random_bernoulli - 1, dtype) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_util.py b/tensorflow/contrib/bayesflow/python/ops/layers_util.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4fecf4e5dcb1e1008303b07b4f76d5e5ce557f --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_util.py @@ -0,0 +1,180 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for probabilistic layers. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import normal as normal_lib + + +def default_loc_scale_fn( + is_singular=False, + loc_initializer=init_ops.random_normal_initializer(stddev=0.1), + untransformed_scale_initializer=init_ops.random_normal_initializer( + mean=-3., stddev=0.1), + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. + + This function produces a closure which produces `loc`, `scale` using + `tf.get_variable`. The closure accepts the following arguments: + + dtype: Type of parameter's event. + shape: Python `list`-like representing the parameter's event shape. + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` indicating if `scale is None`. Default: `False`. + loc_initializer: Initializer function for the `loc` parameters. + The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. Default value: `tf.random_normal_initializer(mean=-3., + stddev=0.1)`. This implies the softplus transformed result has mean + approximately `0.05` and std. deviation approximately `0.005`. + loc_regularizer: Regularizer function for the `loc` parameters. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. The default (`None`) is to use the `tf.get_variable` default. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. The default + (`None`) is to use the `tf.get_variable` default. + + Returns: + default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` + parameters from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates `loc`, `scale` parameters.""" + loc = add_variable_fn( + name=name + "_loc", + shape=shape, + initializer=loc_initializer, + regularizer=loc_regularizer, + constraint=loc_constraint, + dtype=dtype, + trainable=trainable) + if is_singular: + return loc, None + untransformed_scale = add_variable_fn( + name=name + "_untransformed_scale", + shape=shape, + initializer=untransformed_scale_initializer, + regularizer=untransformed_scale_regularizer, + constraint=untransformed_scale_constraint, + dtype=dtype, + trainable=trainable) + scale = (np.finfo(dtype.as_numpy_dtype).eps + + nn_ops.softplus(untransformed_scale)) + return loc, scale + return _fn + + +def default_mean_field_normal_fn( + is_singular=False, + loc_initializer=None, + untransformed_scale_initializer=None, + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Creates a function to build Normal distributions with trainable params. + + This function produces a closure which produces `tf.distributions.Normal` + parameterized by a loc` and `scale` each created using `tf.get_variable`. The + produced closure accepts the following arguments: + + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` if `True`, forces the special case limit of + `scale->0`, i.e., a `Deterministic` distribution. + loc_initializer: Initializer function for the `loc` parameters. + If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + loc_regularizer: Regularizer function for the `loc` parameters. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. + + Returns: + make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` + using from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + loc_scale_fn_ = default_loc_scale_fn( + is_singular, + loc_initializer, + untransformed_scale_initializer, + loc_regularizer, + untransformed_scale_regularizer, + loc_constraint, + untransformed_scale_constraint) + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates multivariate `Deterministic` or `Normal` distribution.""" + loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) + if scale is None: + dist = deterministic_lib.Deterministic(loc=loc) + else: + dist = normal_lib.Normal(loc=loc, scale=scale) + reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] + return independent_lib.Independent( + dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims) + return _fn diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 66a04d42e93331de74b6f3d41f83f071115c1097..392ac7fa1ce600a64ee3b941b70b01447645e4aa 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -359,8 +359,8 @@ tf_custom_op_library( ], deps = [ "//tensorflow/contrib/boosted_trees/lib:example_partitioner", - "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", "//tensorflow/contrib/boosted_trees/lib:models", + "//tensorflow/contrib/boosted_trees/lib:node-stats", "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", @@ -404,10 +404,12 @@ tf_kernel_library( name = "split_handler_ops_kernels", srcs = ["kernels/split_handler_ops.cc"], deps = [ - "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", + "//tensorflow/contrib/boosted_trees/lib:node-stats", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:protos_all_cc", + "//third_party/eigen3", ], alwayslink = 1, ) @@ -599,6 +601,7 @@ py_library( ":init_py", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees/estimator_batch:custom_export_strategy", + "//tensorflow/contrib/boosted_trees/estimator_batch:dnn_tree_combined_estimator", "//tensorflow/contrib/boosted_trees/estimator_batch:init_py", "//tensorflow/contrib/boosted_trees/estimator_batch:trainer_hooks", "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 7792c7127c0285dc2eb5b213da054674f6a81d64..48084d80167cc5c300ae62eaeac53c622dfce2a3 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -50,6 +50,7 @@ py_library( deps = [ "//tensorflow/contrib/learn", "//tensorflow/core:protos_all_py", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", "//tensorflow/python:training", @@ -129,3 +130,33 @@ py_library( "//tensorflow/python:math_ops", ], ) + +py_library( + name = "dnn_tree_combined_estimator", + srcs = ["dnn_tree_combined_estimator.py"], + srcs_version = "PY2AND3", + deps = [ + ":trainer_hooks", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees:model_ops_py", + "//tensorflow/contrib/learn", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "dnn_tree_combined_estimator_test", + size = "small", + srcs = ["dnn_tree_combined_estimator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dnn_tree_combined_estimator", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index ef8dee91b6cc05c4c3dd5eb3c81de4fb65b473e3..6ebc7d7911df878ec91701db8b75feb9a27d18a2 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -33,6 +33,8 @@ from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader as saved_model_loader from tensorflow.python.saved_model import tag_constants +_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" + def make_custom_export_strategy(name, convert_fn, @@ -147,13 +149,12 @@ def convert_to_universal_format(dtec, sorted_feature_names, inequality_test.threshold.float_value = split.threshold elif node_type == "sparse_float_binary_split_default_left": split = gtflow_node.sparse_float_binary_split_default_left.split - node.default_direction = ( - generic_tree_model_pb2.BinaryNode.LEFT) - # TODO(nponomareva): adjust this id assignement when we allow multi- - # column sparse tensors. + node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT) feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -165,7 +166,9 @@ def convert_to_universal_format(dtec, sorted_feature_names, # column sparse tensors. feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -201,10 +204,14 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split_column = feature_names[split.feature_column] elif node_type == "sparse_float_binary_split_default_left": split = tree_node.sparse_float_binary_split_default_left.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "sparse_float_binary_split_default_right": split = tree_node.sparse_float_binary_split_default_right.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "categorical_id_binary_split": split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py index 4ed18b2d34c5af47826ab1c058f5d13797593bd4..492d9ca40c5cfa84e186020605429aacc02af6a6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the conversion code from GTFlow format to Chauffeur.""" +"""Tests for the conversion code and for feature importances export. + +Tests that cover conversion from TFBT format to a tensorflow.contrib. +decision_tree generic_tree_model format and feature importances export. +""" from __future__ import absolute_import from __future__ import division @@ -95,10 +99,31 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } + nodes { + sparse_float_binary_split_default_right { + split { + feature_column: 1 + dimension_id:3 + threshold: -0.4 + left_id: 7 + right_id: 8 + } + } + node_metadata { + gain: 3600 + } + } + nodes { + leaf { + vector { + value: 0.36 + } + } + } nodes { leaf { vector { - value: 0.3 + value: 18 } } } @@ -108,17 +133,25 @@ class ConvertModelTest(test_util.TensorFlowTestCase): """ dtec = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge(dtec_str, dtec) - feature_columns = ["feature_b", "feature_a", "feature_d"] + feature_columns = [ + "feature_b", + "feature_a", + "feature_a_m", + "feature_d", + ] return dtec, feature_columns def testConvertModel(self): dtec, feature_columns = self._make_trees() + # Assume 2 sparse float columns, one with 1 dimension, the second one with + # 5 dimensions. # The feature columns in the order they were added. out = custom_export_strategy.convert_to_universal_format( - dtec, feature_columns, 1, 1, - 1) + dtec, feature_columns, 1, 2, 1) + # Features a and a_m are sparse float features, a_m is multidimensional. expected_tree = """ features { key: "feature_a" } + features { key: "feature_a_m" } features { key: "feature_b" } features { key: "feature_d" } model { @@ -169,7 +202,6 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } - nodes { node_id { value: 1 @@ -196,7 +228,7 @@ class ConvertModelTest(test_util.TensorFlowTestCase): inequality_left_child_test { feature_id { id { - value: "feature_a" + value: "feature_a_0" } } threshold { @@ -259,14 +291,51 @@ class ConvertModelTest(test_util.TensorFlowTestCase): node_id { value: 6 } + binary_node { + left_child_id { + value: 7 + } + right_child_id { + value: 8 + } + default_direction: RIGHT + inequality_left_child_test { + feature_id { + id { + value: "feature_a_m_3" + } + } + threshold { + float_value: -0.4 + } + } + } + } + nodes { + node_id { + value: 7 + } leaf { vector { value { - float_value: 0.03 + float_value: 0.036 } } } } + nodes { + node_id { + value: 8 + } + leaf { + vector { + value { + float_value: 1.8 + } + } + } + } + } } submodel_id { @@ -280,12 +349,15 @@ class ConvertModelTest(test_util.TensorFlowTestCase): def testFeatureImportance(self): dtec, feature_columns = self._make_trees() feature_importances = custom_export_strategy._get_feature_importances( - dtec, feature_columns, 1, 1, 1) - self.assertItemsEqual(["feature_b", "feature_a", "feature_d"], - feature_importances.keys()) + dtec, feature_columns, 1, 2, 1) + self.assertItemsEqual( + ["feature_b", "feature_a_0", "feature_a_m_3", "feature_d"], + feature_importances.keys()) self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4) - self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4) + self.assertAlmostEqual(50.0, feature_importances["feature_a_0"], places=4) self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4) + self.assertAlmostEqual( + 360.0, feature_importances["feature_a_m_3"], places=4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..cec3892b57655dc967b4e7926f7f5a6a30084487 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -0,0 +1,515 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow estimators for combined DNN + GBDT training model. + +The combined model trains a DNN first, then trains boosted trees to boost the +logits of the DNN. The input layer of the DNN (including the embeddings learned +over sparse features) can optionally be provided to the boosted trees as +an additional input feature. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks +from tensorflow.contrib.boosted_trees.python.ops import model_ops +from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch +from tensorflow.contrib.layers.python.layers import optimizers +from tensorflow.contrib.learn.python.learn.estimators import estimator +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import training_util + + +_DNN_LEARNING_RATE = 0.001 + + +def _get_optimizer(optimizer): + if callable(optimizer): + return optimizer() + else: + return optimizer + + +def _add_hidden_layer_summary(value, tag): + summary.scalar("%s_fraction_of_zero_values" % tag, nn.zero_fraction(value)) + summary.histogram("%s_activation" % tag, value) + + +def _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, + dnn_feature_columns, tree_learner_config, num_trees, + tree_examples_per_layer, + config=None, dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """DNN and GBDT combined model_fn. + + Args: + features: `dict` of `Tensor` objects. + labels: Labels used to train on. + mode: Mode we are in. (TRAIN/EVAL/INFER) + head: A `Head` instance. + dnn_hidden_units: List of hidden units per layer. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + config: `RunConfig` of the estimator. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate of 0.001. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + + Returns: + A `ModelFnOps` object. + Raises: + ValueError: if inputs are not valid. + """ + if not isinstance(features, dict): + raise ValueError("features should be a dictionary of `Tensor`s. " + "Given type: {}".format(type(features))) + + if not dnn_feature_columns: + raise ValueError("dnn_feature_columns must be specified") + + # Build DNN Logits. + dnn_parent_scope = "dnn" + dnn_partitioner = dnn_input_layer_partitioner or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=config.num_ps_replicas, + min_slice_size=64 << 20)) + + with variable_scope.variable_scope( + dnn_parent_scope, + values=tuple(six.itervalues(features)), + partitioner=dnn_partitioner): + + with variable_scope.variable_scope( + "input_from_feature_columns", + values=tuple(six.itervalues(features)), + partitioner=dnn_partitioner) as input_layer_scope: + input_layer = layers.input_from_feature_columns( + columns_to_tensors=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope], + scope=input_layer_scope) + previous_layer = input_layer + for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + with variable_scope.variable_scope( + "hiddenlayer_%d" % layer_id, + values=(previous_layer,)) as hidden_layer_scope: + net = layers.fully_connected( + previous_layer, + num_hidden_units, + activation_fn=dnn_activation_fn, + variables_collections=[dnn_parent_scope], + scope=hidden_layer_scope) + if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: + net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout)) + _add_hidden_layer_summary(net, hidden_layer_scope.name) + previous_layer = net + with variable_scope.variable_scope( + "logits", + values=(previous_layer,)) as logits_scope: + dnn_logits = layers.fully_connected( + previous_layer, + head.logits_dimension, + activation_fn=None, + variables_collections=[dnn_parent_scope], + scope=logits_scope) + _add_hidden_layer_summary(dnn_logits, logits_scope.name) + + def _dnn_train_op_fn(loss): + """Returns the op to optimize the loss.""" + return optimizers.optimize_loss( + loss=loss, + global_step=training_util.get_global_step(), + learning_rate=_DNN_LEARNING_RATE, + optimizer=_get_optimizer(dnn_optimizer), + name=dnn_parent_scope, + variables=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=dnn_parent_scope), + # Empty summaries to prevent optimizers from logging training_loss. + summaries=[]) + + # Build Tree Logits. + global_step = training_util.get_global_step() + with ops.device(global_step.device): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config="", # Initialize an empty ensemble. + name="ensemble_model") + + tree_features = features.copy() + if dnn_input_layer_to_tree: + tree_features["dnn_input_layer"] = input_layer + tree_feature_columns.append(layers.real_valued_column("dnn_input_layer")) + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=config.is_chief, + num_ps_replicas=config.num_ps_replicas, + ensemble_handle=ensemble_handle, + center_bias=tree_center_bias, + examples_per_layer=tree_examples_per_layer, + learner_config=tree_learner_config, + feature_columns=tree_feature_columns, + logits_dimension=head.logits_dimension, + features=tree_features) + + with ops.name_scope("gbdt"): + predictions_dict = gbdt_model.predict(mode) + tree_logits = predictions_dict["predictions"] + + def _tree_train_op_fn(loss): + """Returns the op to optimize the loss.""" + update_op = gbdt_model.train(loss, predictions_dict, labels) + with ops.control_dependencies( + [update_op]), (ops.colocate_with(global_step)): + update_op = state_ops.assign_add(global_step, 1).op + return update_op + + tree_train_logits = dnn_logits + tree_logits + + def _no_train_op_fn(loss): + """Returns a no-op.""" + del loss + return control_flow_ops.no_op() + + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op + + if tree_center_bias: + num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() + + model_fn_ops.training_hooks.extend([ + trainer_hooks.SwitchTrainOp( + dnn_train_op, dnn_steps_to_train, tree_train_op), + trainer_hooks.StopAfterNTrees( + num_trees, attempted_trees, finalized_trees)]) + + return model_fn_ops + + +class DNNBoostedTreeCombinedClassifier(estimator.Estimator): + """A classifier that uses a combined DNN/GBDT model.""" + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + n_classes=2, + weight_column_name=None, + model_dir=None, + config=None, + label_name=None, + label_keys=None, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedClassifier instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + n_classes: The number of label classes. + weight_column_name: The name of weight column. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + label_name: String, name of the key in label dict. Can be null if label + is a tensor (single headed models). + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + """ + head = head_lib.multi_class_head( + n_classes=n_classes, + label_name=label_name, + label_keys=label_keys, + weight_column_name=weight_column_name, + enable_centered_bias=False) + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, + tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedClassifier, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) + + +class DNNBoostedTreeCombinedRegressor(estimator.Estimator): + """A regressor that uses a combined DNN/GBDT model.""" + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + weight_column_name=None, + model_dir=None, + config=None, + label_name=None, + label_dimension=1, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedRegressor instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + weight_column_name: The name of weight column. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + label_name: String, name of the key in label dict. Can be null if label + is a tensor (single headed models). + label_dimension: Number of regression labels per example. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + """ + head = head_lib.regression_head( + label_name=label_name, + label_dimension=label_dimension, + weight_column_name=weight_column_name, + enable_centered_bias=False) + + # num_classes needed for GradientBoostedDecisionTreeModel + if label_dimension == 1: + tree_learner_config.num_classes = 2 + else: + tree_learner_config.num_classes = label_dimension + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) + + +class DNNBoostedTreeCombinedEstimator(estimator.Estimator): + """An estimator that uses a combined DNN/GBDT model. + + Useful for training with user specified `Head`. + """ + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + head, + model_dir=None, + config=None, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedEstimator instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + head: `Head` instance. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + """ + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, + tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..83d58c561008e8a5a69eb503d1605bb9e940f281 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -0,0 +1,105 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for combined DNN + GBDT estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile + +from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils +from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +def _train_input_fn(): + features = { + "x": constant_op.constant([[2.], [1.], [1.]]) + } + label = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32) + return features, label + + +def _eval_input_fn(): + features = { + "x": constant_op.constant([[1.], [2.], [2.]]) + } + label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32) + return features, label + + +class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): + + def testClassifierContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedClassifier) + + def testRegressorContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedRegressor) + + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedEstimator) + + def testNoDNNFeatureColumns(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + + with self.assertRaisesRegexp( + ValueError, + "dnn_feature_columns must be specified"): + classifier = estimator.DNNBoostedTreeCombinedClassifier( + dnn_hidden_units=[1], + dnn_feature_columns=[], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2) + classifier.fit(input_fn=_train_input_fn, steps=5) + + def testFitAndEvaluateDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.DNNBoostedTreeCombinedClassifier( + dnn_hidden_units=[1], + dnn_feature_columns=[feature_column.real_valued_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py index 79193fffc3d3fa97e20a12181bf20e6ad86dcb58..2e4151cac40f770e2bece70d752122eb7f34dd40 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py @@ -24,6 +24,7 @@ from tensorflow.contrib.learn.python.learn import session_run_hook from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArgs from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.summary_io import SummaryWriterCache @@ -175,3 +176,40 @@ class StopAfterNTrees(session_run_hook.SessionRunHook): logging.info("Requesting stop since we have reached %d trees.", num_finalized_trees) run_context.request_stop() + + +class SwitchTrainOp(session_run_hook.SessionRunHook): + """Hook that switches the train op after specified number of steps. + + Hook that replaces the train op depending on the number of steps of training + that have taken place. The first_train_op is used till train_steps steps + are reached. Thereafter the second_train_op is used. + """ + + def __init__(self, first_train_op, train_steps, second_train_op): + """Initializes a `SwitchTrainOp`.""" + self._first_train_op = first_train_op + self._second_train_op = second_train_op + self._train_steps = train_steps + + def _get_train_op_for_global_step(self, current_step): + """Gets train_op for current global step.""" + if current_step < self._train_steps: + return self._first_train_op + return self._second_train_op + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + self._current_train_op = control_flow_ops.no_op() + if self._global_step_tensor is None: + raise RuntimeError( + "Global step should be created to use SwitchTrainOp.") + + def before_run(self, run_context): # pylint: disable=unused-argument + return session_run_hook.SessionRunArgs( + {"global_step": self._global_step_tensor, + "train_op": self._current_train_op}) + + def after_run(self, run_context, run_values): + self._current_train_op = self._get_train_op_for_global_step( + run_values.results["global_step"]) diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py index 2c0a3c4912b82aba88e2f8f1b97a227c894ee2ae..e9dbdb0fd784052eeb36ac1aa9342165ef2ac0a7 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -22,7 +22,7 @@ r"""Demonstrates a regression on Boston housing data. python tensorflow/contrib/boosted_trees/examples/boston.py \ --batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \ - --num_eval_steps=1 --num_trees=500 --l2=4 \ + --num_eval_steps=1 --num_trees=500 --l2=0.001 \ --vmodule=training_ops=1 When training is done, mean squared error on eval data is reported. @@ -37,8 +37,10 @@ from __future__ import division from __future__ import print_function import argparse +import os import sys import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch import custom_export_strategy from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column @@ -51,22 +53,18 @@ _BOSTON_NUM_FEATURES = 13 def _get_tfbt(output_dir, feature_cols): """Configures TF Boosted Trees estimator based on flags.""" learner_config = learner_pb2.LearnerConfig() - learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate learner_config.regularization.l1 = 0.0 - # Set the regularization per instance in such a way that - # regularization for the full training data is equal to l2 flag. - learner_config.regularization.l2 = FLAGS.l2 / FLAGS.batch_size + learner_config.regularization.l2 = FLAGS.l2 learner_config.constraints.max_tree_depth = FLAGS.depth - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) # Create a TF Boosted trees regression estimator. estimator = GradientBoostedDecisionTreeRegressor( learner_config=learner_config, - # For the WHOLE_TREE strategy, set the examples_per_layer to be equal to - # batch size. + # This should be the number of examples. For large datasets it can be + # larger than the batch_size. examples_per_layer=FLAGS.batch_size, feature_columns=feature_cols, label_dimension=1, @@ -77,6 +75,14 @@ def _get_tfbt(output_dir, feature_cols): return estimator +def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float, + num_sparse_int, export_dir, unused_eval_result): + universal_format = custom_export_strategy.convert_to_universal_format( + dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int) + with tf.gfile.GFile(os.path.join(export_dir, "tree_proto"), "w") as f: + f.write(str(universal_format)) + + def _make_experiment_fn(output_dir): """Creates experiment for gradient boosted decision trees.""" (x_train, y_train), (x_test, @@ -88,21 +94,31 @@ def _make_experiment_fn(output_dir): batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = tf.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) ] - + feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( + feature_columns) + serving_input_fn = tf.contrib.learn.utils.build_parsing_serving_input_fn( + feature_spec) + # An export strategy that outputs the feature importance and also exports + # the internal tree representation in another format. + export_strategy = custom_export_strategy.make_custom_export_strategy( + "exports", + convert_fn=_convert_fn, + feature_columns=feature_columns, + export_input_fn=serving_input_fn) return tf.contrib.learn.Experiment( estimator=_get_tfbt(output_dir, feature_columns), train_input_fn=train_input_fn, eval_input_fn=eval_input_fn, train_steps=None, eval_steps=FLAGS.num_eval_steps, - eval_metrics=None) + eval_metrics=None, + export_strategies=[export_strategy]) def main(unused_argv): diff --git a/tensorflow/contrib/boosted_trees/examples/boston_combined.py b/tensorflow/contrib/boosted_trees/examples/boston_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..e04b56afbfd266dc13a5b0d78d171ea273415ee3 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/boston_combined.py @@ -0,0 +1,165 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Regression on Boston housing data using DNNBoostedTreeCombinedRegressor. + + Example Usage: + + python tensorflow/contrib/boosted_trees/examples/boston_combined.py \ + --batch_size=404 --output_dir="/tmp/boston" \ + --dnn_hidden_units="8,4" --dnn_steps_to_train=1000 \ + --tree_depth=4 --tree_learning_rate=0.1 \ + --num_trees=100 --tree_l2=0.001 --num_eval_steps=1 \ + --vmodule=training_ops=1 + + When training is done, mean squared error on eval data is reported. + Point tensorboard to the directory for the run to see how the training + progresses: + + tensorboard --logdir=/tmp/boston + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +import tensorflow as tf + +from tensorflow.contrib.boosted_trees.estimator_batch.dnn_tree_combined_estimator import DNNBoostedTreeCombinedRegressor +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn.python.learn import learn_runner +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils + +_BOSTON_NUM_FEATURES = 13 + + +def _get_estimator(output_dir, feature_cols): + """Configures DNNBoostedTreeCombinedRegressor based on flags.""" + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = ( + FLAGS.tree_learning_rate) + learner_config.regularization.l1 = 0.0 + learner_config.regularization.l2 = FLAGS.tree_l2 + learner_config.constraints.max_tree_depth = FLAGS.tree_depth + + run_config = tf.contrib.learn.RunConfig(save_summary_steps=1) + + # Create a DNNBoostedTreeCombinedRegressor estimator. + estimator = DNNBoostedTreeCombinedRegressor( + dnn_hidden_units=[int(x) for x in FLAGS.dnn_hidden_units.split(",")], + dnn_feature_columns=feature_cols, + tree_learner_config=learner_config, + num_trees=FLAGS.num_trees, + # This should be the number of examples. For large datasets it can be + # larger than the batch_size. + tree_examples_per_layer=FLAGS.batch_size, + model_dir=output_dir, + config=run_config, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=FLAGS.dnn_steps_to_train) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for DNNBoostedTreeCombinedRegressor.""" + (x_train, y_train), (x_test, + y_test) = tf.keras.datasets.boston_housing.load_data() + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_train}, + y=y_train, + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) + + feature_columns = [ + feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) + ] + feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( + feature_columns) + serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) + export_strategies = [ + saved_model_export_utils.make_export_strategy(serving_input_fn)] + return tf.contrib.learn.Experiment( + estimator=_get_estimator(output_dir, feature_columns), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None, + export_strategies=export_strategies) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for configuring DNNBoostedTreeCombinedRegressor. + parser.add_argument( + "--dnn_hidden_units", + type=str, + default="8,4", + help="Hidden layers for DNN.") + parser.add_argument( + "--dnn_steps_to_train", + type=int, + default=1000, + help="Number of steps to train DNN.") + parser.add_argument( + "--tree_depth", type=int, default=4, help="Maximum depth of trees.") + parser.add_argument( + "--tree_l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--tree_learning_rate", + type=float, + default=0.1, + help=("Learning rate (shrinkage weight) with which each " + "new tree is added.")) + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 3bd30d8678920c1320bf6fedc2f40f5922237a92..18b4abd654ea3541d646a43ac901aca1a678446f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -16,7 +16,7 @@ #include #include -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/core/framework/device_base.h" @@ -490,11 +490,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { } dense_split->set_feature_column(feature_column_group_id_); // Set the feature index for the best feature column. - const int64 best_feature_id = + const int64 best_dimension_id = bucket_ids_and_dimensions(best_element_idx, 1); const int32 best_bucket_id = bucket_ids_and_dimensions(best_element_idx, 0); - dense_split->set_feature_id(best_feature_id); + dense_split->set_dimension_id(best_dimension_id); dense_split->set_threshold(bucket_boundaries(best_bucket_id)); auto* left_child = split_info.mutable_left_child(); diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 107ff0d295bee530c1711a97849fbd3c6cdb2f00..131bd48562a55a08981ac73277e93024db0d85d3 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -406,51 +406,9 @@ tf_cc_test( ) # Learner/stochastic - -cc_library( - name = "feature-column-handlers", - srcs = [ - "learner/stochastic/handlers/bias-feature-column-handler.cc", - "learner/stochastic/handlers/categorical-feature-column-handler.cc", - "learner/stochastic/handlers/dense-quantized-feature-column-handler.cc", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc", - ], - hdrs = [ - "learner/stochastic/handlers/bias-feature-column-handler.h", - "learner/stochastic/handlers/categorical-feature-column-handler.h", - "learner/stochastic/handlers/dense-quantized-feature-column-handler.h", - "learner/stochastic/handlers/feature-column-handler.h", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler.h", - ], - deps = [ - ":feature-split-candidate", - ":feature-stats-accumulator", - "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:protos_all_cc", - ], -) - -tf_cc_test( - name = "feature-column-handlers_test", - size = "small", - srcs = [ - "learner/stochastic/handlers/bias-feature-column-handler_test.cc", - "learner/stochastic/handlers/categorical-feature-column-handler_test.cc", - "learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc", - ], - deps = [ - ":feature-column-handlers", - "//tensorflow/core:tensor_testutil", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "gradient-stats", - hdrs = ["learner/stochastic/stats/gradient-stats.h"], + hdrs = ["learner/common/stats/gradient-stats.h"], deps = [ "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", @@ -459,7 +417,7 @@ cc_library( cc_library( name = "node-stats", - hdrs = ["learner/stochastic/stats/node-stats.h"], + hdrs = ["learner/common/stats/node-stats.h"], deps = [ ":gradient-stats", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", @@ -471,7 +429,7 @@ cc_library( cc_library( name = "split-stats", - hdrs = ["learner/stochastic/stats/split-stats.h"], + hdrs = ["learner/common/stats/split-stats.h"], deps = [ ":node-stats", ], @@ -479,7 +437,7 @@ cc_library( cc_library( name = "feature-split-candidate", - hdrs = ["learner/stochastic/stats/feature-split-candidate.h"], + hdrs = ["learner/common/stats/feature-split-candidate.h"], deps = [ ":split-stats", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", @@ -489,7 +447,7 @@ cc_library( tf_cc_test( name = "node-stats_test", size = "small", - srcs = ["learner/stochastic/stats/node-stats_test.cc"], + srcs = ["learner/common/stats/node-stats_test.cc"], deps = [ ":node-stats", "//tensorflow/core:tensor_testutil", diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 72e20aaa127cda592bd314786cddb925cc87a075..7df514cd207c5e781f3b4abaa2020016b197669d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -436,7 +436,7 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, quantized_feature = quantile_ops.quantiles([float_column], [], [quantile_buckets], [], []) quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.squeeze(quantized_feature) + quantized_feature = array_ops.squeeze(quantized_feature, axis=0) return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): @@ -468,7 +468,7 @@ def sparse_make_stats_update( [sparse_column_indices]) quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64) - quantized_feature = array_ops.squeeze(quantized_feature) + quantized_feature = array_ops.squeeze(quantized_feature, axis=0) example_indices, _ = array_ops.split( sparse_column_indices, num_or_size_splits=2, axis=1) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index ee16a5f838a65f20db4436eb86527518621b6d8d..54d03018d9e266beabbbabd78ebbb80cfe689c04 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -1121,6 +1121,87 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testDegenerativeCase(self): + with self.test_session() as sess: + # One data example only, one leaf and thus one quantile bucket.The same + # situation is when all examples have the same values. This case was + # causing before a failure. + gradients = array_ops.constant([0.2]) + hessians = array_ops.constant([0.12]) + example_partitions = array_ops.constant([1], dtype=dtypes.int32) + indices = array_ops.constant([[0, 0]], dtype=dtypes.int64) + values = array_ops.constant([0.58]) + sparse_column = sparse_tensor.SparseTensor(indices, values, [1, 1]) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([1, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(1, 2, class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([1], partitions) + self.assertAllEqual([0.0], gains) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.58, split_node.split.threshold) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h similarity index 90% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h index fe22691178213094b9affcdee06af98011f85bd2..339c2e0fded10e6a7b140da62e152e2868ffd164 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h @@ -13,10 +13,10 @@ // limitations under the License. // // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" namespace tensorflow { @@ -58,4 +58,4 @@ struct FeatureSplitCandidate { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h similarity index 98% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h index dad64bf165a41bc4f32eea6b37e7afb569887a06..34e3ddb777242553d62035a51f1aec33d0f9ba54 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ #include @@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h similarity index 98% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h index 4e5f53874df2207ffa6664a33675f84ef055394b..642a183aec5c7e591579fa5ee91d45729bfb624d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Eigenvalues" -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h" #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/core/framework/shape_inference.h" @@ -298,4 +298,4 @@ struct NodeStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc similarity index 99% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc index ecb7a04efb96248210d9af770c8377b7f6906598..f867e77d3ef0609774628b2a9c36ca52bcf2a957 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h similarity index 94% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h index f700cbced833543227de39f54c9ecbb03a7ce7c9..054ccd9a8cd0be0c48b14cca013f15677deba900 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ #include -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" namespace tensorflow { namespace boosted_trees { @@ -81,4 +81,4 @@ struct SplitStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc deleted file mode 100644 index b880cf2c47989b1434f17802befb7dd7c248b36f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -void BiasFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all examples and aggregate gradient stats for each sub-root. - for (int64 example_idx = 0; example_idx < batch_size_; ++example_idx) { - auto partition_id = example_partition_ids[example_idx]; - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, kBiasFeatureId, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void BiasFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - split_candidates->clear(); - split_candidates->reserve(roots.size()); - boosted_trees::trees::TreeNode tree_node; - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - const NodeStats& root_node_stats = root_stats[root_idx]; - tree_node.Clear(); - root_node_stats.FillLeaf(class_id_, tree_node.mutable_leaf()); - split_candidates->emplace_back(slot_id_, tree_node, - SplitStats(learner_config, root_node_stats)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h deleted file mode 100644 index 5c0f99185a63db33a391a98fa16f37bef99507c9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a bias feature column in the single class case. -// This handler is useful even if we don't introduce a bias feature because -// it allows us to aggregate stats per partition which in turn allows us -// to compute node stats for each root to split. -class BiasFeatureColumnHandler : public FeatureColumnHandler { - public: - BiasFeatureColumnHandler(const uint32 class_id, const uint32 slot_id, - const int64 batch_size) - : FeatureColumnHandler(class_id, slot_id, batch_size) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - static constexpr auto kBiasFeatureId = 0; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc deleted file mode 100644 index f4c7df7fabda1a38d7e6cca4c5c8bc81cb7551b1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 7; -const auto kSlotId = 0; -const auto kBatchSize = 4; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class BiasFeatureColumnHandlerTest : public ::testing::Test { - protected: - BiasFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 1, 3}) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - - // Create handler. - handler_.reset(new BiasFeatureColumnHandler(kClassId, kSlotId, kBatchSize)); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - std::unique_ptr handler_; -}; - -TEST_F(BiasFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition. - // Partition 0. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(-0.3f, 0.19f), - accumulator.GetStats(kSlotId, kClassId, 0, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 1. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 1, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 2. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 2, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 3. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 3, - BiasFeatureColumnHandler::kBiasFeatureId)); -} - -TEST_F(BiasFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 3. - // Root 0 has zero gain and root 3 has the same gain as the leaf. - const std::vector roots = {0, 3}; - const std::vector& root_stats = { - NodeStats(1), NodeStats(learner_config_, GradientStats(4.0f, 0.13f))}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect two candidate splits (one per root). - EXPECT_EQ(2, split_candidates.size()); - - // Verify first candidate for root 0, gain is expected to be the same as - // the left child since the root node gain is zero. - const SplitStats expected_split_stats0(learner_config_, root_stats[0]); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kLeaf, tree_node0.node_case()); - EXPECT_EQ(1, tree_node0.leaf().sparse_vector().index_size()); - EXPECT_EQ(kClassId, tree_node0.leaf().sparse_vector().index(0)); - EXPECT_EQ(1, tree_node0.leaf().sparse_vector().value_size()); - EXPECT_EQ(root_stats[0].weight_contribution[0], - tree_node0.leaf().sparse_vector().value(0)); - - // Verify second candidate for root 3, gain is expected to be zero as - // the left child gain is equal to the parent gain. - const SplitStats expected_split_stats1(learner_config_, root_stats[1]); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kLeaf, tree_node1.node_case()); - EXPECT_EQ(1, tree_node1.leaf().sparse_vector().index_size()); - EXPECT_EQ(kClassId, tree_node1.leaf().sparse_vector().index(0)); - EXPECT_EQ(1, tree_node1.leaf().sparse_vector().value_size()); - EXPECT_EQ(root_stats[1].weight_contribution[0], - tree_node1.leaf().sparse_vector().value(0)); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc deleted file mode 100644 index 3a6c409f846c9ca0bd6b5101e96447642b949978..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h" - -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a categorical Id split node without assigning children. -boosted_trees::trees::TreeNode CreateCategoricalIdNode( - const int32 feature_column, const int32 id) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_categorical_id_binary_split(); - split->set_feature_column(feature_column); - split->set_feature_id(id); - return split_node; -} - -} // namespace - -void CategoricalFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all rows and aggregate gradient stats for each feature id. - const int64 num_rows = indices_.dimension(0); - for (int64 row_idx = 0; row_idx < num_rows; ++row_idx) { - auto example_idx = indices_(row_idx, 0); - auto feature_id = values_(row_idx); - const GradientStats norm_gradient_stats(example_first_order_gradients, - example_second_order_gradients, - example_idx); - auto partition_id = example_partition_ids[example_idx]; - gradient_stats_accumulator->AddStats(slot_id_, class_id_, partition_id, - feature_id, norm_gradient_stats); - } -} - -void CategoricalFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Build a reverse lookup of partition id to root idx. - std::unordered_map partition_id_to_root_idx; - partition_id_to_root_idx.reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - partition_id_to_root_idx[roots[root_idx]] = root_idx; - } - - // Initialize split candidates. - split_candidates->clear(); - if (!roots.empty()) { - FeatureSplitCandidate empty_candidate( - root_stats[0].weight_contribution.size()); - split_candidates->resize(roots.size(), empty_candidate); - } - for (auto& split_candidate : *split_candidates) { - split_candidate.split_stats.gain = std::numeric_limits::lowest(); - } - - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we evaluate every feature id as an equality split - // and pick the highest split gain. - for (const auto& entry : - gradient_stats_accumulator.GetFeatureStats(slot_id_)) { - DCHECK_EQ(entry.first.class_id, class_id_); - - // Get partition id and root node stats. - const int32 partition_id = entry.first.partition_id; - auto root_idx_it = partition_id_to_root_idx.find(partition_id); - if (root_idx_it == partition_id_to_root_idx.end()) { - // Inactive partition. - continue; - } - size_t root_idx = root_idx_it->second; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Get gradient stats. - const auto& left_gradient_stats = entry.second; - auto right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats. - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate and update best split candidate for the - // current root if needed. - FeatureSplitCandidate split_candidate( - slot_id_, - CreateCategoricalIdNode(feature_column_, entry.first.feature_id), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - FeatureSplitCandidate& best_split_candidate = (*split_candidates)[root_idx]; - if (TF_PREDICT_FALSE(best_split_candidate.tree_node.node_case() == - boosted_trees::trees::TreeNode::NODE_NOT_SET)) { - // Always replace candidates with no node set. - best_split_candidate = std::move(split_candidate); - } else if (TF_PREDICT_FALSE(split_candidate.split_stats.gain == - best_split_candidate.split_stats.gain)) { - // Tie break on feature id. - auto best_split_feature_id = - best_split_candidate.tree_node.categorical_id_binary_split() - .feature_id(); - if (entry.first.feature_id < best_split_feature_id) { - best_split_candidate = std::move(split_candidate); - } - } else if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h deleted file mode 100644 index ef964ba716c6adf9cf9c291cca5f52f7a6efe26f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a categorical feature column in the single class case. -class CategoricalFeatureColumnHandler : public FeatureColumnHandler { - public: - CategoricalFeatureColumnHandler(const int32 class_id, const int32 slot_id, - const int64 batch_size, - const int32 feature_column, - TTypes::ConstMatrix indices, - TTypes::ConstVec values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - feature_column_(feature_column), - indices_(indices), - values_(values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 feature_column_; - TTypes::ConstMatrix indices_; - TTypes::ConstVec values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc deleted file mode 100644 index ea82b3f086d24dc1f9ceb4783abd68be35b34b00..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 7; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 3; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class CategoricalFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Feature Id | - // i0 | (0.2, 0.12) | 0 | 1,2 | - // i1 | (-0.5, 0.07) | 0 | | - // i2 | (1.2, 0.2) | 0 | 2 | - // i3 | (4.0, 0.13) | 1 | 0 | - CategoricalFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - indices_(test::AsTensor({0, 0, 0, 1, 2, 0, 3, 0}, {4, 2})), - values_(test::AsTensor({1, 2, 2, 0}, {4})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. - handler_.reset(new CategoricalFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, indices_.matrix(), - values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor indices_; - const Tensor values_; - std::unique_ptr handler_; -}; - -TEST_F(CategoricalFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f, 0.12f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 0, Feature 2. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f + 1.2f, 0.12f + 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 2)); - - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); - // Partition 1, Feature 2. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 2)); -} - -TEST_F(CategoricalFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. - const std::vector roots = {0, 1, 5}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i0, i2 left and i1 right. - const NodeStats expected_left_node0(learner_config_, - GradientStats(0.2f + 1.2f, 0.12f + 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ( - boosted_trees::trees::TreeNode::kCategoricalIdBinarySplitFieldNumber, - tree_node0.node_case()); - const auto& split0 = tree_node0.categorical_id_binary_split(); - EXPECT_EQ(2, split0.feature_id()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active feature here - // so zero gain is expected. - const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ( - boosted_trees::trees::TreeNode::kCategoricalIdBinarySplitFieldNumber, - tree_node1.node_case()); - const auto& split1 = tree_node1.categorical_id_binary_split(); - EXPECT_EQ(0, split1.feature_id()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 5. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc deleted file mode 100644 index ca7bb71e7d0b0fc945ee29092b1e36022d4c0943..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a dense split node without assigning children. -boosted_trees::trees::TreeNode CreateDenseSplitNode(const int32 feature_column, - const float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_dense_float_binary_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -} // namespace - -void DenseQuantizedFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all examples and aggregate gradient stats for each partition - // and quantized feature bucket. - for (int64 example_idx = 0; example_idx < batch_size_; ++example_idx) { - auto partition_id = example_partition_ids[example_idx]; - auto feature_id = dense_quantized_values_(example_idx); - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, feature_id, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void DenseQuantizedFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we do a forward-only pass over the quantized - // feature buckets accumulating gradients from left to right. - // Split gains are evaluated at every threshold and the best split is picked. - split_candidates->clear(); - split_candidates->reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - // Get partition Id and root node stats. - const int32 partition_id = roots[root_idx]; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Forward left to right pass over quantiles. - GradientStats left_gradient_stats; - GradientStats right_gradient_stats(root_node_stats.gradient_stats); - FeatureSplitCandidate best_split_candidate( - root_node_stats.weight_contribution.size()); - best_split_candidate.split_stats.gain = - std::numeric_limits::lowest(); - for (int bucket_id = 0; bucket_id < dense_quantiles_.size(); ++bucket_id) { - // Get gradient stats. - auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. - left_gradient_stats += gradient_stats; - right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. - const float threshold = dense_quantiles_(bucket_id); - FeatureSplitCandidate split_candidate( - slot_id_, CreateDenseSplitNode(dense_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - - // Add best candidate for partition. - split_candidates->push_back(std::move(best_split_candidate)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h deleted file mode 100644 index 0f3858e4d8c406e9ec3ae7079b241e94ef4aa35c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a dense quantized feature column in the single class case. -class DenseQuantizedFeatureColumnHandler : public FeatureColumnHandler { - public: - DenseQuantizedFeatureColumnHandler( - const int32 class_id, const int32 slot_id, const int64 batch_size, - const int32 dense_feature_column, TTypes::ConstVec dense_quantiles, - TTypes::ConstVec dense_quantized_values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - dense_feature_column_(dense_feature_column), - dense_quantiles_(dense_quantiles), - dense_quantized_values_(dense_quantized_values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 dense_feature_column_; - TTypes::ConstVec dense_quantiles_; - TTypes::ConstVec dense_quantized_values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc deleted file mode 100644 index 1bc9d733ad3090f1cfc9547644061f54d7d2c8c6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 1; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 2; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class DenseQuantizedFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Dense Quantile | - // i0 | (0.2, 0.12) | 0 | 1 | - // i1 | (-0.5, 0.07) | 0 | 1 | - // i2 | (1.2, 0.2) | 0 | 0 | - // i3 | (4.0, 0.13) | 1 | 1 | - DenseQuantizedFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - dense_quantiles_(test::AsTensor({0.3f, 0.52f}, {2})), - dense_quantized_values_(test::AsTensor({1, 1, 0, 1}, {4})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. - handler_.reset(new DenseQuantizedFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, - dense_quantiles_.vec(), dense_quantized_values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor dense_quantiles_; - const Tensor dense_quantized_values_; - std::unique_ptr handler_; -}; - -TEST_F(DenseQuantizedFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(-0.3f, 0.19f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); -} - -TEST_F(DenseQuantizedFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. - const std::vector roots = {0, 1, 5}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i2 left and i0, i1 right. - const NodeStats expected_left_node0(learner_config_, - GradientStats(1.2f, 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kDenseFloatBinarySplit, - tree_node0.node_case()); - const auto& split0 = tree_node0.dense_float_binary_split(); - EXPECT_FLOAT_EQ(dense_quantiles_.vec()(0), split0.threshold()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active bucket here - // so zero gain is expected. - const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kDenseFloatBinarySplit, - tree_node1.node_case()); - const auto& split1 = tree_node1.dense_float_binary_split(); - EXPECT_FLOAT_EQ(dense_quantiles_.vec()(1), split1.threshold()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 5. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h deleted file mode 100644 index 8bd2092f9609cb684b89f70cab35a92789fb39a4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ - -#include -#include "tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h" -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h" -#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler interface for feature columns. Each feature column type may -// have its own handler which encapsulates the logic of aggregating gradient -// stats as well as generating split candidates for each partition. -// Handlers can be stateful and must be thread compatible. -class FeatureColumnHandler { - public: - FeatureColumnHandler(const int32 class_id, const int32 slot_id, - const int64 batch_size) - : class_id_(class_id), slot_id_(slot_id), batch_size_(batch_size) {} - - virtual ~FeatureColumnHandler() {} - FeatureColumnHandler(const FeatureColumnHandler& other) = delete; - FeatureColumnHandler& operator=(const FeatureColumnHandler& other) = delete; - - // Aggregates example gradient stats for the feature column. - virtual void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const = 0; - - // Generates feature column split candidates for the specified roots. - virtual void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const = 0; - - // Accessors. - int32 class_id() const { return class_id_; } - int32 slot_id() const { return slot_id_; } - int64 batch_size() const { return batch_size_; } - - protected: - // The class Id. - const int32 class_id_; - - // The slod Id for use as a unique Id across all feature columns. - const int32 slot_id_; - - // Size of the batch of examples. - const int64 batch_size_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc deleted file mode 100644 index a0e9efbbc5030e8c2e25fafab98271337a2e582a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a sparse default right split node without assigning children. -boosted_trees::trees::TreeNode CreateSparseSplitNodeDefaultRight( - int32 feature_column, float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_sparse_float_binary_split_default_right() - ->mutable_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -// Creates a sparse default left split node without assigning children. -boosted_trees::trees::TreeNode CreateSparseSplitNodeDefaultLeft( - int32 feature_column, float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_sparse_float_binary_split_default_left() - ->mutable_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -} // namespace - -void SparseQuantizedFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all rows and aggregate gradient stats for each partition - // and quantized feature bucket. - const int64 num_rows = sparse_indices_.dimension(0); - for (int64 row_idx = 0; row_idx < num_rows; ++row_idx) { - auto example_idx = sparse_indices_(row_idx, 0); - auto partition_id = example_partition_ids[example_idx]; - auto feature_id = sparse_quantized_values_(row_idx); - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, feature_id, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void SparseQuantizedFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we do both a forward left to right pass and a backward - // right to left pass over the quantized feature buckets accumulating - // gradients on one side and using the root aggregate gradients to get the - // gradients for the other side. Split gains are evaluated for each pass at - // every threshold and the best split is picked. - split_candidates->clear(); - split_candidates->reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - // Get partition Id and root node stats. - const int32 partition_id = roots[root_idx]; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Forward pass with right default direction. - GradientStats left_gradient_stats; - GradientStats right_gradient_stats(root_node_stats.gradient_stats); - FeatureSplitCandidate best_split_candidate( - root_node_stats.weight_contribution.size()); - best_split_candidate.split_stats.gain = - std::numeric_limits::lowest(); - for (int bucket_id = 0; bucket_id < sparse_quantiles_.size(); ++bucket_id) { - // Get gradient stats. - auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. - left_gradient_stats += gradient_stats; - right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. - const float threshold = sparse_quantiles_(bucket_id); - FeatureSplitCandidate split_candidate( - slot_id_, - CreateSparseSplitNodeDefaultRight(sparse_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - - // Determine if we need a backward pass by checking if the residual gradient - // after forward aggregation is almost the same as the aggregated gradient. - // for the current root. This helps avoid unnecessary computation as well - // as consistency due to floating point precision. - if (!right_gradient_stats.IsAlmostZero()) { - // Backward pass with left default direction. - right_gradient_stats = GradientStats(); - left_gradient_stats = root_node_stats.gradient_stats; - for (int bucket_id = sparse_quantiles_.size() - 1; bucket_id > 0; - --bucket_id) { - // Get gradient stats. - auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. - right_gradient_stats += gradient_stats; - left_gradient_stats = root_node_stats.gradient_stats - gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. - const float threshold = sparse_quantiles_(bucket_id - 1); - FeatureSplitCandidate split_candidate( - slot_id_, - CreateSparseSplitNodeDefaultLeft(sparse_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - } - - // Add best candidate for partition. - split_candidates->push_back(std::move(best_split_candidate)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h deleted file mode 100644 index eb63e705471a65e8448bda38b2e31eb971d5c1bb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a sparse quantized feature column in the single class case. -class SparseQuantizedFeatureColumnHandler : public FeatureColumnHandler { - public: - SparseQuantizedFeatureColumnHandler( - const int32 class_id, const int32 slot_id, const int64 batch_size, - const int32 sparse_feature_column, - TTypes::ConstVec sparse_quantiles, - TTypes::ConstMatrix sparse_indices, - TTypes::ConstVec sparse_quantized_values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - sparse_feature_column_(sparse_feature_column), - sparse_quantiles_(sparse_quantiles), - sparse_indices_(sparse_indices), - sparse_quantized_values_(sparse_quantized_values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 sparse_feature_column_; - TTypes::ConstVec sparse_quantiles_; - TTypes::ConstMatrix sparse_indices_; - TTypes::ConstVec sparse_quantized_values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc deleted file mode 100644 index 643d936ad23850e601bc5518d69c8637011f53c0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 3; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 4; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class SparseQuantizedFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Sparse Quantile | - // i0 | (0.2, 0.12) | 0 | 1 | - // i1 | (-0.5, 0.07) | 0 | N/A | - // i2 | (1.2, 0.2) | 0 | 0 | - // i3 | (4.0, 0.13) | 1 | 1 | - SparseQuantizedFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - sparse_quantiles_(test::AsTensor({0.3f, 0.52f}, {2})), - sparse_indices_(test::AsTensor({0, 0, 2, 0, 3, 0}, {3, 2})), - sparse_quantized_values_(test::AsTensor({1, 0, 1}, {3})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. - handler_.reset(new SparseQuantizedFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, - sparse_quantiles_.vec(), sparse_indices_.matrix(), - sparse_quantized_values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor sparse_quantiles_; - const Tensor sparse_indices_; - const Tensor sparse_quantized_values_; - std::unique_ptr handler_; -}; - -TEST_F(SparseQuantizedFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f, 0.12f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); -} - -TEST_F(SparseQuantizedFeatureColumnHandlerTest, - GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. - const std::vector roots = {0, 1, 9}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i0 and i2 to the left and i1 to the right (by default direction). - const NodeStats expected_left_node0(learner_config_, - GradientStats(0.2f + 1.2f, 0.12f + 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kSparseFloatBinarySplitDefaultRight, - tree_node0.node_case()); - const auto& split0 = - tree_node0.sparse_float_binary_split_default_right().split(); - EXPECT_FLOAT_EQ(sparse_quantiles_.vec()(1), split0.threshold()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active bucket here - // so zero gain is expected. - const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kSparseFloatBinarySplitDefaultRight, - tree_node1.node_case()); - const auto& split1 = - tree_node1.sparse_float_binary_split_default_right().split(); - EXPECT_FLOAT_EQ(sparse_quantiles_.vec()(1), split1.threshold()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 9. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc index f8750e7191673274772fc869c198dd5fbbefbc49..0e5578693a7b90b16eada1127cad992612fb6dad 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc @@ -52,13 +52,13 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, example.sparse_float_features[split.feature_column()]; // Feature id for the split when multivalent sparse float column, or 0 // by default. - const int32 feature_id = split.feature_id(); + const int32 dimension_id = split.dimension_id(); - node_id = - !sparse_feature[feature_id].has_value() || - sparse_feature[feature_id].get_value() <= split.threshold() - ? split.left_id() - : split.right_id(); + node_id = !sparse_feature[dimension_id].has_value() || + sparse_feature[dimension_id].get_value() <= + split.threshold() + ? split.left_id() + : split.right_id(); break; } case TreeNode::kSparseFloatBinarySplitDefaultRight: { @@ -68,12 +68,12 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, example.sparse_float_features[split.feature_column()]; // Feature id for the split when multivalent sparse float column, or 0 // by default. - const int32 feature_id = split.feature_id(); - node_id = - sparse_feature[feature_id].has_value() && - sparse_feature[feature_id].get_value() <= split.threshold() - ? split.left_id() - : split.right_id(); + const int32 dimension_id = split.dimension_id(); + node_id = sparse_feature[dimension_id].has_value() && + sparse_feature[dimension_id].get_value() <= + split.threshold() + ? split.left_id() + : split.right_id(); break; } case TreeNode::kCategoricalIdBinarySplit: { diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc index 93924d429c19aef51b6f1d85655de3798a76e3e0..58fe8e335af28fe811c1ee785578aa58d898335b 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc @@ -190,7 +190,7 @@ TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { tree_config.add_nodes()->mutable_leaf(); // Split on first column - split_node->set_feature_id(0); + split_node->set_dimension_id(0); split_node->set_threshold(2.0f); // Both instances have this feature value. @@ -199,7 +199,7 @@ TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); // Split on second column - split_node->set_feature_id(1); + split_node->set_dimension_id(1); split_node->set_threshold(5.0f); // First instance does not have it (default right), second does have it. @@ -208,7 +208,7 @@ TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); // Split on third column - split_node->set_feature_id(2); + split_node->set_dimension_id(2); split_node->set_threshold(3.0f); example_it = example_iterable.begin(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index 7a550d6f7328765d8815a947885e47fa0b0a8f8b..badc629a118f768d5aa25ef1b94b8190e6910c7f 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -56,7 +56,7 @@ class BatchFeatures { *num_sparse_int_features = sparse_int_feature_columns_.size(); if (*num_dense_float_features == 0 && *num_sparse_float_features == 0 && *num_sparse_int_features == 0) { - return errors::FailedPrecondition("Not intialized yet."); + return errors::FailedPrecondition("Not initialized yet."); } return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc index 0d46565a1962b88cbb267f3d6043610758790578..ccee9530b6897924453461c13b1238402c0f6cfa 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc @@ -97,7 +97,7 @@ class IndicesRowIterator } bool operator<(const IndicesRowIterator& other) const { - return (row_idx_ < other.row_idx_); + return (row_idx_ < other.row_idx_); } bool operator==(const IndicesRowIterator& other) const { diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index f14abf45a517ad7c4c6d7bb1ab88b7a1d47d6fb6..fc570c1083d01a65760a456c109dad93afd9f62a 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,9 +53,9 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; - // If feature column is multivalent, this holds the index of the feature for - // the split. Defaults to 0. - int32 feature_id = 5; + // If feature column is multivalent, this holds the index of the dimensiong + // for the split. Defaults to 0. + int32 dimension_id = 5; float threshold = 2; // Node children indexing into a contiguous diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 9ada844601afbe7f0a6993444c7c4ed0e16a01ca..c1acf351603dd80c2d14c7ee0a5b4c89706bc1bf 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -93,7 +93,7 @@ def _set_float_split(split, feat_col, thresh, l_id, r_id, feature_dim_id=None): split.left_id = l_id split.right_id = r_id if feature_dim_id is not None: - split.feature_id = feature_dim_id + split.dimension_id = feature_dim_id def _set_categorical_id_split(split, feat_col, feat_id, l_id, r_id): diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 2a72961504b7e8a256afd8f77dce79ba756230f0..888d5c57ed33446c8b6f18d2d1e393647613d132 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -48,15 +48,16 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): def testBasicQuantileBuckets(self): """Sets up the quantile summary op test as follows. - Create a batch of 6 examples having a dense and sparse features. + Create a batch of 6 examples having a dense and sparse features. SparseM is + a sparse multi-dimensional (multivalent) feature. The data looks like this - | Instance | instance weights | Dense 0 | Sparse 0 - | 0 | 10 | 1 | - | 1 | 1 | 2 | 2 - | 2 | 1 | 3 | 3 - | 3 | 1 | 4 | 4 - | 4 | 1 | 4 | 5 - | 5 | 1 | 5 | 6 + | Instance | instance weights | Dense 0 | Sparse 0 | SparseM + | 0 | 10 | 1 | | | | + | 1 | 1 | 2 | 2 | 2 | | + | 2 | 1 | 3 | 3 | 3 | | + | 3 | 1 | 4 | 4 | | 4 | + | 4 | 1 | 4 | 5 | | 5 | + | 5 | 1 | 5 | 6 | | 6 | """ dense_float_tensor_0 = constant_op.constant( @@ -66,20 +67,29 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): sparse_values_0 = constant_op.constant( [2, 3, 4, 5, 6], dtype=dtypes.float32) sparse_shape_0 = constant_op.constant([6, 1], dtype=dtypes.int64) + # Multi-dimensional feature that should have the same quantiles as Sparse 0. + sparse_indices_m = constant_op.constant( + [[1, 1], [2, 0], [3, 1], [4, 1], [5, 1]], dtype=dtypes.int64) + sparse_values_m = constant_op.constant( + [2, 3, 4, 5, 6], dtype=dtypes.float32) + sparse_shape_m = constant_op.constant([6, 2], dtype=dtypes.int64) + example_weights = constant_op.constant( [10, 1, 1, 1, 1, 1], dtype=dtypes.float32) with self.test_session(): config = self._gen_config(0.33, 3) dense_buckets, sparse_buckets = quantile_ops.quantile_buckets( - [dense_float_tensor_0], [sparse_indices_0], [sparse_values_0], - [sparse_shape_0], + [dense_float_tensor_0], [sparse_indices_0, sparse_indices_m], + [sparse_values_0, sparse_values_m], [sparse_shape_0, sparse_shape_m], example_weights=example_weights, dense_config=[config], - sparse_config=[config]) + sparse_config=[config, config]) self.assertAllEqual([1, 3, 5], dense_buckets[0].eval()) self.assertAllEqual([2, 4, 6.], sparse_buckets[0].eval()) + # Multidimensional sparse. + self.assertAllEqual([2, 4, 6.], sparse_buckets[1].eval()) def testStreamingQuantileBucketsWithVaryingBatch(self): """Sets up the quantile summary op test as follows. @@ -214,10 +224,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() sparse_indices_0 = constant_op.constant( - [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]], dtype=dtypes.int64) + [[1, 0], [2, 1], [3, 0], [4, 2], [5, 0]], dtype=dtypes.int64) sparse_values_0 = constant_op.constant( [2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtypes.float32) - sparse_shape_0 = constant_op.constant([6, 1], dtype=dtypes.int64) + sparse_shape_0 = constant_op.constant([6, 3], dtype=dtypes.int64) example_weights = constant_op.constant( [10, 1, 1, 1, 1, 1], dtype=dtypes.float32, shape=[6, 1]) update = accumulator.add_summary( diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 7c2e3a3b208c696731ef12be5e9cbab66dc99355..28834ef55bf8e1f32cc8f2380a4be3bf3824d8e1 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -240,7 +240,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertEqual(0, split_node.split.feature_column) # Sparse is one dimensional. - self.assertEqual(0, split_node.split.feature_id) + self.assertEqual(0, split_node.split.dimension_id) self.assertAllClose(0.52, split_node.split.threshold) @@ -263,7 +263,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertEqual(0, split_node.split.feature_column) # Sparse is one dimensional. - self.assertEqual(0, split_node.split.feature_id) + self.assertEqual(0, split_node.split.dimension_id) self.assertAllClose(0.52, split_node.split.threshold) @@ -373,7 +373,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertEqual(0, split_node.split.feature_column) # Split happened on second dimension. - self.assertEqual(1, split_node.split.feature_id) + self.assertEqual(1, split_node.split.dimension_id) self.assertAllClose(0.58, split_node.split.threshold) @@ -395,7 +395,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([expected_right_weight], right_child.value) self.assertEqual(0, split_node.split.feature_column) - self.assertEqual(2, split_node.split.feature_id) + self.assertEqual(2, split_node.split.dimension_id) self.assertAllClose(0.6, split_node.split.threshold) diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 7e8e15e7d8c89d1adaa472b1da7e8bb3c73ca17e..294e04002adac62fc123a3242a05a1b36f422433 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -45,6 +45,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token, epsilon, num_quantiles, + max_elements=None, name=None, container=None): """Creates a QuantileAccumulator object. @@ -53,6 +54,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token: The initial value for the stamp token. epsilon: Error bound on the quantile computation. num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. name: the name to save the accumulator under. container: An optional `string`. Defaults to `""` """ @@ -67,6 +69,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): self._quantile_accumulator_handle, init_stamp_token, epsilon=epsilon, + max_elements=max_elements, num_quantiles=num_quantiles) is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( self._quantile_accumulator_handle) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 6094dae6b59d8b05bb12a28cf167a536e6825287..b95956dae2a62b28643cd31815c5f5650eca337b 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -322,9 +322,11 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") (fc_names, dense_floats, sparse_float_indices, sparse_float_values, diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 16e24d97ddee0751e0b808b89080074c1b4baba7..dba51d4f527792d2a8dedc693f74c07119fd231d 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -912,8 +912,10 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEqual(1, len(output.trees[0].nodes[2].leaf.sparse_vector.index)) self.assertEqual(3, output.trees[0].nodes[2].leaf.sparse_vector.index[0]) - self.assertAlmostEqual( - 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0]) + self.assertAllClose( + 0.893284678459, + output.trees[0].nodes[2].leaf.sparse_vector.value[0], + atol=1e-4, rtol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py index dde16426863b60e9df64da1ee6b36caec273bfd6..ccb8509c0347f9c9b6f1e8f4f620230aac9a6c2d 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as np from tensorflow.contrib.boosted_trees.python.utils import losses @@ -60,35 +58,27 @@ class LossesTest(test_util.TensorFlowTestCase): neg_loss = loss_for_negatives.eval() # For positive labels, points <= 0.3 get max loss of e. # For negative labels, these points have minimum loss of 1/e. - for i in range(2): - self.assertAlmostEqual(math.exp(1), pos_loss[i], places=4) - self.assertAlmostEqual(math.exp(-1), neg_loss[i], places=4) + self.assertAllClose(np.exp(np.ones([2, 1])), pos_loss[:2], atol=1e-4) + self.assertAllClose(np.exp(-np.ones([2, 1])), neg_loss[:2], atol=1e-4) # For positive lables, p oints with predictions 0.7 and larger get minimum # loss value of 1/e. For negative labels, these points are wrongly # classified and get loss e. - for i in range(6, 10): - self.assertAlmostEqual(math.exp(-1), pos_loss[i], places=4) - self.assertAlmostEqual(math.exp(1), neg_loss[i], places=4) + self.assertAllClose(np.exp(-np.ones([4, 1])), pos_loss[6:10], atol=1e-4) + self.assertAllClose(np.exp(np.ones([4, 1])), neg_loss[6:10], atol=1e-4) # Points in between 0.5-eps, 0..5+eps get loss exp(-label_m*y), where # y = 1/eps *x -1/(2eps), where x is the probability and label_m is either # 1 or -1 (for label of 0). - for i in range(2, 6): - self.assertAlmostEqual( - math.exp(-1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), - pos_loss[i], - places=4) - self.assertAlmostEqual( - math.exp(1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), - neg_loss[i], - places=4) + self.assertAllClose( + np.exp(-(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps)), + pos_loss[2:6], atol=1e-4) + self.assertAllClose( + np.exp(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps), + neg_loss[2:6], atol=1e-4) def test_per_example_squared_loss(self): - def _squared_loss(p, y): - return np.mean(1.0 * (p - y) * (p - y)) - labels = np.array([[0.123], [224.2], [-3], [2], [.3]], dtype=np.float32) weights = array_ops.ones([5, 1], dtypes.float32) predictions = np.array( @@ -99,9 +89,8 @@ class LossesTest(test_util.TensorFlowTestCase): predictions) loss = loss_tensor.eval() - for i in range(5): - self.assertAlmostEqual( - _squared_loss(labels[i], predictions[i]), loss[i], places=4) + self.assertAllClose( + np.square(labels[:5] - predictions[:5]), loss[:5], atol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index aa8f5ed12bc6f779e3c1a923b9225ec283189747..fe8bd072afd43a64fa62a65bd8900b5a98dbe761 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -60,9 +60,7 @@ tf_py_test( size = "small", srcs = ["python/ops/bigquery_reader_ops_test.py"], additional_deps = [ - ":bigquery_reader_ops_op_lib", ":cloud_py", - "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc index b31b882fa19a7eaad304d6d423961234f9affef4..e9b79a066def566096d6c3f3745974423e3371d1 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -421,7 +421,7 @@ TEST_F(BigQueryTableAccessorTest, MultiplePagesTest) { TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); EXPECT_EQ(3, row_id); EXPECT_TRUE(accessor_->Done()); - + Example expected_example; ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, &expected_example)); diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index f0144e9faa26801b6491b242b04fda8905f15306..c74da9cabd6816bc9c7891e32937534cff2d677d 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -80,13 +80,9 @@ class TPUClusterResolver(ClusterResolver): raise ImportError('googleapiclient must be installed before using the ' 'TPU cluster resolver') - # TODO(b/67375680): Remove custom URL once TPU APIs are finalized self._service = discovery.build( - 'tpu', - 'v1', - credentials=self._credentials, - discoveryServiceUrl='https://storage.googleapis.com' - '/tpu-api-definition/v1alpha1.json') + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 77a3fc0c8322117f50265e56952b68480583de02..8d023cc81dd73751f0b5690f3649ded3fc038155 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,6 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) -option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) @@ -34,6 +33,13 @@ option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) +option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." OFF) + +# GPU, CUDA and cuDNN options +option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) +set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") +set(tensorflow_CUDNN_VERSION "7" CACHE STRING "cuDNN version to build against") + if(HAIKU) option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) else() @@ -53,7 +59,15 @@ if (NOT WIN32) set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) + if (NOT tensorflow_PATH_CUDNN_STATIC_LIB) + # option's default value is OFF. Fill it with real default values + set (tensorflow_PATH_CUDNN_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) + endif (NOT tensorflow_PATH_CUDNN_STATIC_LIB) option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) + if (NOT tensorflow_PATH_NCCL_STATIC_LIB) + # option's default value is OFF. Fill it with real default values + set (tensorflow_PATH_NCCL_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) + endif (NOT tensorflow_PATH_NCCL_STATIC_LIB) option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) if (NOT tensorflow_CUDA_LIBRARY_PATH) # option's default value is OFF. Fill it with real default values @@ -92,6 +106,10 @@ else() set(CMAKE_POSITION_INDEPENDENT_CODE OFF) endif() +if (tensorflow_DISABLE_EIGEN_FORCEINLINE) + add_definitions(-DEIGEN_STRONG_INLINE=inline) +endif() + add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC) @@ -262,7 +280,7 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - find_package(CUDA 8.0 REQUIRED) + find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED) # by default we assume compute cabability 3.5 and 5.2. If you change this change it in # CUDA_NVCC_FLAGS and cuda_config.h below @@ -316,13 +334,16 @@ if (tensorflow_ENABLE_GPU) ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) endif (WIN32) + # Remove "." from CUDA version variable. + string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + # create cuda_config.h FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" - "#define TF_CUDA_VERSION \"64_80\"\n" - "#define TF_CUDNN_VERSION \"64_6\"\n" + "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" + "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -360,15 +381,15 @@ if (tensorflow_ENABLE_GPU) if(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll - cudart_dll_name=cudart64_80.dll - cuda_version_number=8.0 + cudart_dll_name=cudart64_${short_CUDA_VER}.dll + cuda_version_number=${tensorflow_CUDA_VERSION} nvcuda_dll_name=nvcuda.dll - cudnn_dll_name=cudnn64_6.dll - cudnn_version_number=6) + cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll + cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value - cuda_version_number=8.0 - cudnn_version_number=6) + cuda_version_number=${tensorflow_CUDA_VERSION} + cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value @@ -383,11 +404,7 @@ endif() # Let's get to work! include(tf_core_framework.cmake) -# NOTE: Disabled until issue #3996 is fixed. -# include(tf_stream_executor.cmake) -if (tensorflow_ENABLE_GPU) - include(tf_stream_executor.cmake) -endif() +include(tf_stream_executor.cmake) include(tf_core_cpu.cmake) include(tf_core_ops.cmake) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 4ddfec5960d2b759bacb376202cd8dab6ef2b024..4be733a2809f366a214fa2bb853bccffb10ecaba 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -19,23 +19,6 @@ for instructions on how to install a pre-built TensorFlow package on Windows. ### Current known limitations * It is not possible to load a custom Op library. * GCS file system is not supported. -* The following Ops are not currently implemented: - - Dequantize - - QuantizeAndDequantize - - QuantizedAvgPool - - QuantizedBatchNomWithGlobalNormalization - - QuantizedBiasAdd - - QuantizedConcat - - QuantizedConv2D - - QuantizedMatmul - - QuantizedMaxPoo - - QuantizeDownAndShrinkRange - - QuantizedRelu - - QuantizedRelu6 - - QuantizedReshape - - QuantizeV2 - - RequantizationRange - - Requantize ## Building with CMake diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index 3b146657bfc9bdd54db14839195af45972e67aff..a235442dc5c0a07e249653381436eeae81575883 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(gemmlowp_URL https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip) -set(gemmlowp_HASH SHA256=dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d) +set(gemmlowp_URL https://github.com/google/gemmlowp/archive/6a2a90822e8546fc2bfa7044de0faf1c1cb4862f.zip) +set(gemmlowp_HASH SHA256=3447948d219f3270383766bbe08942888c0eb4e0ca6663c0e0548502ec5bb77d) set(gemmlowp_BUILD ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 155c91cb97dbe5ef33c318efb5544a9fa22166c7..05080060479b6240edb8ab9f65160b3dd182feb9 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 93815892dddafe9146a5f7e7042281d59d0f4323) +set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/re2.cmake b/tensorflow/contrib/cmake/external/re2.cmake index b56f4b089813247f3ab1c751538ba4b05cacb5b6..d10f5959f71dd350e6e2bcb81be8882b203fb231 100644 --- a/tensorflow/contrib/cmake/external/re2.cmake +++ b/tensorflow/contrib/cmake/external/re2.cmake @@ -45,4 +45,5 @@ ExternalProject_Add(re2 endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${re2_INSTALL} + -DRE2_BUILD_TESTING:BOOL=OFF ) diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index 594c2492d4fd68b50c8493321a2c4dcc2d41917e..aaae18a313dd082b428654091c9411600c981ec9 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -158,12 +158,21 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) endif () endif () diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt new file mode 100644 index 0000000000000000000000000000000000000000..92edce77dff699e75d1873ad0f56c6c489fbc571 --- /dev/null +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -0,0 +1,453 @@ +tensorflow +tensorflow/core +tensorflow/core/example +tensorflow/core/framework +tensorflow/core/lib +tensorflow/core/lib/core +tensorflow/core/protobuf +tensorflow/core/util +tensorflow/examples +tensorflow/examples/tutorials +tensorflow/examples/tutorials/mnist +tensorflow/python +tensorflow/python/client +tensorflow/python/data +tensorflow/python/data/ops +tensorflow/python/data/util +tensorflow/python/debug +tensorflow/python/debug/cli +tensorflow/python/debug/examples +tensorflow/python/debug/lib +tensorflow/python/debug/wrappers +tensorflow/python/eager +tensorflow/python/estimator +tensorflow/python/estimator/canned +tensorflow/python/estimator/export +tensorflow/python/estimator/inputs +tensorflow/python/estimator/inputs/queues +tensorflow/python/feature_column +tensorflow/python/framework +tensorflow/python/grappler +tensorflow/python/keras +tensorflow/python/keras/activations +tensorflow/python/keras/applications +tensorflow/python/keras/applications/inception_resnet_v2 +tensorflow/python/keras/applications/inception_v3 +tensorflow/python/keras/applications/mobilenet +tensorflow/python/keras/applications/resnet50 +tensorflow/python/keras/applications/vgg16 +tensorflow/python/keras/applications/vgg19 +tensorflow/python/keras/applications/xception +tensorflow/python/keras/backend +tensorflow/python/keras/callbacks +tensorflow/python/keras/constraints +tensorflow/python/keras/datasets +tensorflow/python/keras/datasets/boston_housing +tensorflow/python/keras/datasets/cifar10 +tensorflow/python/keras/datasets/cifar100 +tensorflow/python/keras/datasets/fashion_mnist +tensorflow/python/keras/datasets/imdb +tensorflow/python/keras/datasets/mnist +tensorflow/python/keras/datasets/reuters +tensorflow/python/keras/estimator +tensorflow/python/keras/initializers +tensorflow/python/keras/layers +tensorflow/python/keras/losses +tensorflow/python/keras/metrics +tensorflow/python/keras/models +tensorflow/python/keras/optimizers +tensorflow/python/keras/preprocessing +tensorflow/python/keras/preprocessing/image +tensorflow/python/keras/preprocessing/sequence +tensorflow/python/keras/preprocessing/text +tensorflow/python/keras/regularizers +tensorflow/python/keras/utils +tensorflow/python/keras/wrappers +tensorflow/python/keras/wrappers/scikit_learn +tensorflow/python/keras/_impl +tensorflow/python/keras/_impl/keras +tensorflow/python/keras/_impl/keras/applications +tensorflow/python/keras/_impl/keras/datasets +tensorflow/python/keras/_impl/keras/engine +tensorflow/python/keras/_impl/keras/layers +tensorflow/python/keras/_impl/keras/preprocessing +tensorflow/python/keras/_impl/keras/utils +tensorflow/python/keras/_impl/keras/wrappers +tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/distributions +tensorflow/python/kernel_tests/linalg +tensorflow/python/kernel_tests/random +tensorflow/python/layers +tensorflow/python/lib +tensorflow/python/lib/core +tensorflow/python/lib/io +tensorflow/python/ops +tensorflow/python/ops/distributions +tensorflow/python/ops/linalg +tensorflow/python/ops/losses +tensorflow/python/platform +tensorflow/python/platform/default +tensorflow/python/platform/summary +tensorflow/python/profiler/ +tensorflow/python/profiler/internal +tensorflow/python/saved_model +tensorflow/python/summary +tensorflow/python/summary/writer +tensorflow/python/tools +tensorflow/python/training +tensorflow/python/user_ops +tensorflow/python/util +tensorflow/python/util/protobuf +tensorflow/tools +tensorflow/tools/graph_transforms +tensorflow/contrib +tensorflow/contrib/all_reduce +tensorflow/contrib/all_reduce/python +tensorflow/contrib/android +tensorflow/contrib/android/java +tensorflow/contrib/android/java/org +tensorflow/contrib/android/java/org/tensorflow +tensorflow/contrib/android/java/org/tensorflow/contrib +tensorflow/contrib/android/java/org/tensorflow/contrib/android +tensorflow/contrib/android/jni +tensorflow/contrib/batching +tensorflow/contrib/batching/kernels +tensorflow/contrib/batching/python +tensorflow/contrib/batching/python/ops +tensorflow/contrib/bayesflow +tensorflow/contrib/bayesflow/examples +tensorflow/contrib/bayesflow/examples/reinforce_simple +tensorflow/contrib/bayesflow/python +tensorflow/contrib/bayesflow/python/ops +tensorflow/contrib/boosted_trees +tensorflow/contrib/boosted_trees/estimator_batch +tensorflow/contrib/boosted_trees/kernels +tensorflow/contrib/boosted_trees/ops +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/boosted_trees/python +tensorflow/contrib/boosted_trees/python/ops +tensorflow/contrib/cloud +tensorflow/contrib/cloud/kernels +tensorflow/contrib/cloud/ops +tensorflow/contrib/cloud/python +tensorflow/contrib/cloud/python/ops +tensorflow/contrib/cluster_resolver +tensorflow/contrib/cluster_resolver/python +tensorflow/contrib/cluster_resolver/python/training +tensorflow/contrib/compiler +tensorflow/contrib/copy_graph +tensorflow/contrib/copy_graph/python +tensorflow/contrib/copy_graph/python/util +tensorflow/contrib/crf +tensorflow/contrib/crf/python +tensorflow/contrib/crf/python/ops +tensorflow/contrib/cudnn_rnn +tensorflow/contrib/cudnn_rnn/kernels +tensorflow/contrib/cudnn_rnn/ops +tensorflow/contrib/cudnn_rnn/python +tensorflow/contrib/cudnn_rnn/python/layers +tensorflow/contrib/cudnn_rnn/python/ops +tensorflow/contrib/data +tensorflow/contrib/data/kernels +tensorflow/contrib/data/python +tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/ops +tensorflow/contrib/decision_trees +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/deprecated +tensorflow/contrib/distributions +tensorflow/contrib/distributions/python +tensorflow/contrib/distributions/python/ops +tensorflow/contrib/distributions/python/ops/bijectors +tensorflow/contrib/eager +tensorflow/contrib/eager/python +tensorflow/contrib/estimator +tensorflow/contrib/estimator/python +tensorflow/contrib/estimator/python/estimator +tensorflow/contrib/factorization +tensorflow/contrib/factorization/examples +tensorflow/contrib/factorization/kernels +tensorflow/contrib/factorization/ops +tensorflow/contrib/factorization/python +tensorflow/contrib/factorization/python/ops +tensorflow/contrib/ffmpeg +tensorflow/contrib/ffmpeg/default +tensorflow/contrib/framework +tensorflow/contrib/framework/kernels +tensorflow/contrib/framework/ops +tensorflow/contrib/framework/python +tensorflow/contrib/framework/python/framework +tensorflow/contrib/framework/python/ops +tensorflow/contrib/fused_conv +tensorflow/contrib/fused_conv/kernels +tensorflow/contrib/fused_conv/python +tensorflow/contrib/fused_conv/python/ops +tensorflow/contrib/gan +tensorflow/contrib/gan/python +tensorflow/contrib/gan/python/estimator +tensorflow/contrib/gan/python/estimator/python +tensorflow/contrib/gan/python/eval +tensorflow/contrib/gan/python/eval/python +tensorflow/contrib/gan/python/features +tensorflow/contrib/gan/python/features/python +tensorflow/contrib/gan/python/losses +tensorflow/contrib/gan/python/losses/python +tensorflow/contrib/graph_editor +tensorflow/contrib/graph_editor/examples +tensorflow/contrib/grid_rnn +tensorflow/contrib/grid_rnn/python +tensorflow/contrib/grid_rnn/python/ops +tensorflow/contrib/hooks +tensorflow/contrib/hooks/python +tensorflow/contrib/image +tensorflow/contrib/image/kernels +tensorflow/contrib/image/ops +tensorflow/contrib/image/python +tensorflow/contrib/image/python/ops +tensorflow/contrib/input_pipeline +tensorflow/contrib/input_pipeline/kernels +tensorflow/contrib/input_pipeline/ops +tensorflow/contrib/input_pipeline/python +tensorflow/contrib/input_pipeline/python/ops +tensorflow/contrib/integrate +tensorflow/contrib/integrate/python +tensorflow/contrib/integrate/python/ops +tensorflow/contrib/ios_examples +tensorflow/contrib/ios_examples/benchmark +tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj +tensorflow/contrib/ios_examples/benchmark/data +tensorflow/contrib/ios_examples/camera +tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj +tensorflow/contrib/ios_examples/camera/en.lproj +tensorflow/contrib/ios_examples/simple +tensorflow/contrib/ios_examples/simple/data +tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj +tensorflow/contrib/keras +tensorflow/contrib/keras/api +tensorflow/contrib/keras/api/keras +tensorflow/contrib/keras/api/keras/activations +tensorflow/contrib/keras/api/keras/applications +tensorflow/contrib/keras/api/keras/applications/inception_v3 +tensorflow/contrib/keras/api/keras/applications/mobilenet +tensorflow/contrib/keras/api/keras/applications/resnet50 +tensorflow/contrib/keras/api/keras/applications/vgg16 +tensorflow/contrib/keras/api/keras/applications/vgg19 +tensorflow/contrib/keras/api/keras/applications/xception +tensorflow/contrib/keras/api/keras/backend +tensorflow/contrib/keras/api/keras/callbacks +tensorflow/contrib/keras/api/keras/constraints +tensorflow/contrib/keras/api/keras/datasets +tensorflow/contrib/keras/api/keras/datasets/boston_housing +tensorflow/contrib/keras/api/keras/datasets/cifar10 +tensorflow/contrib/keras/api/keras/datasets/cifar100 +tensorflow/contrib/keras/api/keras/datasets/imdb +tensorflow/contrib/keras/api/keras/datasets/mnist +tensorflow/contrib/keras/api/keras/datasets/reuters +tensorflow/contrib/keras/api/keras/initializers +tensorflow/contrib/keras/api/keras/layers +tensorflow/contrib/keras/api/keras/losses +tensorflow/contrib/keras/api/keras/metrics +tensorflow/contrib/keras/api/keras/models +tensorflow/contrib/keras/api/keras/optimizers +tensorflow/contrib/keras/api/keras/preprocessing +tensorflow/contrib/keras/api/keras/preprocessing/image +tensorflow/contrib/keras/api/keras/preprocessing/sequence +tensorflow/contrib/keras/api/keras/preprocessing/text +tensorflow/contrib/keras/api/keras/regularizers +tensorflow/contrib/keras/api/keras/utils +tensorflow/contrib/keras/api/keras/wrappers +tensorflow/contrib/keras/api/keras/wrappers/scikit_learn +tensorflow/contrib/kernel_methods +tensorflow/contrib/kernel_methods/python +tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kfac +tensorflow/contrib/kfac/examples +tensorflow/contrib/kfac/python +tensorflow/contrib/kfac/python/ops +tensorflow/contrib/labeled_tensor +tensorflow/contrib/labeled_tensor/python +tensorflow/contrib/labeled_tensor/python/ops +tensorflow/contrib/layers +tensorflow/contrib/layers/kernels +tensorflow/contrib/layers/ops +tensorflow/contrib/layers/python +tensorflow/contrib/layers/python/layers +tensorflow/contrib/layers/python/ops +tensorflow/contrib/learn +tensorflow/contrib/learn/python +tensorflow/contrib/learn/python/learn +tensorflow/contrib/learn/python/learn/dataframe +tensorflow/contrib/learn/python/learn/dataframe/queues +tensorflow/contrib/learn/python/learn/dataframe/transforms +tensorflow/contrib/learn/python/learn/datasets +tensorflow/contrib/learn/python/learn/datasets/data +tensorflow/contrib/learn/python/learn/estimators +tensorflow/contrib/learn/python/learn/learn_io +tensorflow/contrib/learn/python/learn/ops +tensorflow/contrib/learn/python/learn/preprocessing +tensorflow/contrib/learn/python/learn/utils +tensorflow/contrib/legacy_seq2seq +tensorflow/contrib/legacy_seq2seq/python +tensorflow/contrib/legacy_seq2seq/python/ops +tensorflow/contrib/libsvm +tensorflow/contrib/libsvm/python +tensorflow/contrib/libsvm/python/kernel_tests +tensorflow/contrib/libsvm/python/ops +tensorflow/contrib/linalg +tensorflow/contrib/linalg/python +tensorflow/contrib/linalg/python/ops +tensorflow/contrib/linear_optimizer +tensorflow/contrib/linear_optimizer/kernels +tensorflow/contrib/linear_optimizer/kernels/g3doc +tensorflow/contrib/linear_optimizer/python +tensorflow/contrib/linear_optimizer/python/ops +tensorflow/contrib/lookup +tensorflow/contrib/losses +tensorflow/contrib/losses/python +tensorflow/contrib/losses/python/losses +tensorflow/contrib/losses/python/metric_learning +tensorflow/contrib/makefile +tensorflow/contrib/memory_stats +tensorflow/contrib/memory_stats/kernels +tensorflow/contrib/memory_stats/ops +tensorflow/contrib/memory_stats/python +tensorflow/contrib/memory_stats/python/ops +tensorflow/contrib/meta_graph_transform +tensorflow/contrib/metrics +tensorflow/contrib/metrics/ops +tensorflow/contrib/metrics/python +tensorflow/contrib/metrics/python/metrics +tensorflow/contrib/metrics/python/ops +tensorflow/contrib/model_pruning +tensorflow/contrib/model_pruning/examples +tensorflow/contrib/model_pruning/examples/cifar10 +tensorflow/contrib/model_pruning/python +tensorflow/contrib/model_pruning/python/layers +tensorflow/contrib/nccl +tensorflow/contrib/nccl/kernels +tensorflow/contrib/nccl/ops +tensorflow/contrib/nccl/python +tensorflow/contrib/nccl/python/ops +tensorflow/contrib/ndlstm +tensorflow/contrib/ndlstm/python +tensorflow/contrib/nearest_neighbor/kernels +tensorflow/contrib/nearest_neighbor/ops +tensorflow/contrib/nearest_neighbor/python +tensorflow/contrib/nearest_neighbor/python/ops +tensorflow/contrib/nn +tensorflow/contrib/nn/python +tensorflow/contrib/nn/python/ops +tensorflow/contrib/opt +tensorflow/contrib/opt/python +tensorflow/contrib/opt/python/training +tensorflow/contrib/pi_examples +tensorflow/contrib/pi_examples/camera +tensorflow/contrib/pi_examples/label_image +tensorflow/contrib/pi_examples/label_image/data +tensorflow/contrib/periodic_resample +tensorflow/contrib/periodic_resample/python +tensorflow/contrib/periodic_resample/python/kernels +tensorflow/contrib/periodic_resample/python/ops +tensorflow/contrib/predictor +tensorflow/contrib/quantization +tensorflow/contrib/quantization/python +tensorflow/contrib/quantize +tensorflow/contrib/quantize/python +tensorflow/contrib/receptive_field +tensorflow/contrib/receptive_field/python +tensorflow/contrib/reduce_slice_ops +tensorflow/contrib/reduce_slice_ops/kernels +tensorflow/contrib/reduce_slice_ops/ops +tensorflow/contrib/reduce_slice_ops/python +tensorflow/contrib/reduce_slice_ops/python/ops +tensorflow/contrib/remote_fused_graph/pylib +tensorflow/contrib/remote_fused_graph/pylib/python +tensorflow/contrib/remote_fused_graph/pylib/python/ops +tensorflow/contrib/resampler +tensorflow/contrib/resampler/kernels +tensorflow/contrib/resampler/ops +tensorflow/contrib/resampler/python +tensorflow/contrib/resampler/python/ops +tensorflow/contrib/rnn +tensorflow/contrib/rnn/kernels +tensorflow/contrib/rnn/ops +tensorflow/contrib/rnn/python +tensorflow/contrib/rnn/python/kernel_tests +tensorflow/contrib/rnn/python/ops +tensorflow/contrib/saved_model +tensorflow/contrib/saved_model/python +tensorflow/contrib/saved_model/python/saved_model +tensorflow/contrib/seq2seq +tensorflow/contrib/seq2seq/kernels +tensorflow/contrib/seq2seq/ops +tensorflow/contrib/seq2seq/python +tensorflow/contrib/seq2seq/python/ops +tensorflow/contrib/session_bundle +tensorflow/contrib/session_bundle/example +tensorflow/contrib/signal +tensorflow/contrib/signal/python +tensorflow/contrib/signal/python/ops +tensorflow/contrib/slim +tensorflow/contrib/slim/python +tensorflow/contrib/slim/python/slim +tensorflow/contrib/slim/python/slim/data +tensorflow/contrib/slim/python/slim/nets +tensorflow/contrib/solvers +tensorflow/contrib/solvers/python +tensorflow/contrib/solvers/python/ops +tensorflow/contrib/sparsemax +tensorflow/contrib/sparsemax/python +tensorflow/contrib/sparsemax/python/ops +tensorflow/contrib/specs +tensorflow/contrib/specs/python +tensorflow/contrib/staging +tensorflow/contrib/stat_summarizer +tensorflow/contrib/stat_summarizer/python +tensorflow/contrib/stateless +tensorflow/contrib/stateless/python +tensorflow/contrib/summary +tensorflow/contrib/tensorboard +tensorflow/contrib/tensorboard/plugins +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensor_forest +tensorflow/contrib/tensor_forest/client +tensorflow/contrib/tensor_forest/core +tensorflow/contrib/tensor_forest/core/ops +tensorflow/contrib/tensor_forest/data +tensorflow/contrib/tensor_forest/hybrid +tensorflow/contrib/tensor_forest/hybrid/core +tensorflow/contrib/tensor_forest/hybrid/core/ops +tensorflow/contrib/tensor_forest/hybrid/ops +tensorflow/contrib/tensor_forest/hybrid/python +tensorflow/contrib/tensor_forest/hybrid/python/layers +tensorflow/contrib/tensor_forest/hybrid/python/models +tensorflow/contrib/tensor_forest/hybrid/python/ops +tensorflow/contrib/tensor_forest/kernels +tensorflow/contrib/tensor_forest/python +tensorflow/contrib/tensor_forest/python/ops +tensorflow/contrib/testing +tensorflow/contrib/testing/python +tensorflow/contrib/testing/python/framework +tensorflow/contrib/text +tensorflow/contrib/text/kernels +tensorflow/contrib/text/ops +tensorflow/contrib/text/python +tensorflow/contrib/text/python/ops +tensorflow/contrib/tfprof +tensorflow/contrib/timeseries +tensorflow/contrib/timeseries/examples +tensorflow/contrib/timeseries/examples/data +tensorflow/contrib/timeseries/python +tensorflow/contrib/timeseries/python/timeseries +tensorflow/contrib/timeseries/python/timeseries/state_space_models +tensorflow/contrib/tpu +tensorflow/contrib/tpu/ops +tensorflow/contrib/tpu/profiler +tensorflow/contrib/tpu/python +tensorflow/contrib/tpu/python/ops +tensorflow/contrib/tpu/python/profiler +tensorflow/contrib/tpu/python/tpu +tensorflow/contrib/training +tensorflow/contrib/training/python +tensorflow/contrib/training/python/training +tensorflow/contrib/util diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a9c406d8b118c10ddcaafb0e4fc242aa79cdb57 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -0,0 +1,19 @@ +tensorflow/core +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/cloud/kernels +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/gdr +tensorflow/contrib/lite/toco +tensorflow/contrib/mpi +tensorflow/contrib/mpi_collectives +tensorflow/contrib/session_bundle +tensorflow/contrib/tensor_forest/proto +tensorflow/contrib/tensorboard/graph_explorer/proto +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensorboard/plugins/trace +tensorflow/contrib/tpu/proto +tensorflow/contrib/tpu/profiler +tensorflow/contrib/training/python/training +tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/python_protos_cc.txt b/tensorflow/contrib/cmake/python_protos_cc.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4a257b25c814a1464308d0e6ce3ce65d21f6a36 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos_cc.txt @@ -0,0 +1,5 @@ +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/session_bundle +tensorflow/contrib/tensorboard +tensorflow/contrib/training diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index f63aca4a835e213ef6d420845df9bb537514e142..6e2ac203f9a7f96cb14752a91483840a9eb6b451 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -83,7 +83,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names}) ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.cc - COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} + COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir ) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 5c01ca382fb9cc7a01a6f2b60a510c59f0aa7119..e4213ea2a47da2a7381cccd0504235ad62018d4e 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,7 +63,7 @@ if (tensorflow_ENABLE_GPU) file(GLOB_RECURSE tf_core_gpu_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu/cupti_wrapper.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc" + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.h" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.cc" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index c607546f4a5244fb6e7cd12db874f07a962f6f4d..5ec1a8d04fa41c6b36400fc0998af77592866150 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -211,7 +211,7 @@ if (NOT tensorflow_ENABLE_GPU) list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) else() file(GLOB tf_core_platform_srcs_exclude - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc") + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index b1102cecbe2d64b5bfb8e5ed95ca1478a74c7fa4..d3b6c0bdd385432dc469133c00960ebba0dbeec5 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -55,10 +55,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/model_ops.cc" @@ -89,6 +85,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/libsvm/ops/libsvm_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" @@ -154,9 +152,6 @@ list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) if(WIN32) file(GLOB_RECURSE tf_core_kernels_windows_exclude_srcs # not working on windows yet - "${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/neon/*" # not in core - those are loaded dynamically as dll "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 4a61ed7a3548b1992ddc71acb8a7761e252296ea..e8c2cd347327843d10d13c1d24a800ff776aa8c1 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -92,6 +92,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/con GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(periodic_resample "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nearest_neighbor "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(resampler "${tensorflow_source_dir}/tensorflow/contrib/resampler/ops/resampler_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_grappler.cmake b/tensorflow/contrib/cmake/tf_grappler.cmake index a7841c98e83ec8c3eb91edfd9d639e169cb5f440..410490531a300c091afdd857d7f2d4e789a4c80e 100644 --- a/tensorflow/contrib/cmake/tf_grappler.cmake +++ b/tensorflow/contrib/cmake/tf_grappler.cmake @@ -23,7 +23,7 @@ file(GLOB tf_grappler_srcs "${tensorflow_source_dir}/tensorflow/python/grappler/model_analyzer.cc" "${tensorflow_source_dir}/tensorflow/python/grappler/model_analyzer.h" ) - + add_library(tf_grappler OBJECT ${tf_grappler_srcs}) add_dependencies(tf_grappler tf_core_cpu) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 61b3fd715ddc8f47e1f2724cb805dc5065448619..8db6929e31a1a5f5c793721f455a664bd6741b06 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -120,32 +120,34 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/*.proto" - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/decision_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos.txt python_protos) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}") +STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}") + +foreach(python_proto ${python_protos}) + file(GLOB_RECURSE tf_python_protos_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto}/*.proto" + ) + list(APPEND tf_python_protos_srcs ${tf_python_protos_src}) +endforeach(python_proto) + RELATIVE_PROTOBUF_GENERATE_PYTHON( - ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs} + ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_python_protos_srcs} ) -# NOTE(mrry): Avoid regenerating the tensorflow/core protos because this -# can cause benign-but-failing-on-Windows-due-to-file-locking conflicts -# when two rules attempt to generate the same file. -file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos_cc.txt python_protos_cc) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}") +STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}") + +foreach(python_proto_cc ${python_protos_cc}) + file(GLOB_RECURSE tf_python_protos_cc_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto_cc}/*.proto" + ) + list(APPEND tf_python_protos_cc_srcs ${tf_python_protos_cc_src}) +endforeach(python_proto_cc) + RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS ${tensorflow_source_dir} ${tf_python_protos_cc_srcs} ) @@ -191,315 +193,15 @@ function(add_python_module MODULE_NAME) endif() endfunction() -add_python_module("tensorflow") -add_python_module("tensorflow/core") -add_python_module("tensorflow/core/example") -add_python_module("tensorflow/core/framework") -add_python_module("tensorflow/core/lib") -add_python_module("tensorflow/core/lib/core") -add_python_module("tensorflow/core/protobuf") -add_python_module("tensorflow/core/util") -add_python_module("tensorflow/examples") -add_python_module("tensorflow/examples/tutorials") -add_python_module("tensorflow/examples/tutorials/mnist") -add_python_module("tensorflow/python") -add_python_module("tensorflow/python/client") -add_python_module("tensorflow/python/data") -add_python_module("tensorflow/python/data/ops") -add_python_module("tensorflow/python/data/util") -add_python_module("tensorflow/python/debug") -add_python_module("tensorflow/python/debug/cli") -add_python_module("tensorflow/python/debug/examples") -add_python_module("tensorflow/python/debug/lib") -add_python_module("tensorflow/python/debug/wrappers") -add_python_module("tensorflow/python/eager") -add_python_module("tensorflow/python/estimator") -add_python_module("tensorflow/python/estimator/canned") -add_python_module("tensorflow/python/estimator/export") -add_python_module("tensorflow/python/estimator/inputs") -add_python_module("tensorflow/python/estimator/inputs/queues") -add_python_module("tensorflow/python/feature_column") -add_python_module("tensorflow/python/framework") -add_python_module("tensorflow/python/grappler") -add_python_module("tensorflow/python/keras") -add_python_module("tensorflow/python/keras/activations") -add_python_module("tensorflow/python/keras/applications") -add_python_module("tensorflow/python/keras/applications/inception_resnet_v2") -add_python_module("tensorflow/python/keras/applications/inception_v3") -add_python_module("tensorflow/python/keras/applications/mobilenet") -add_python_module("tensorflow/python/keras/applications/resnet50") -add_python_module("tensorflow/python/keras/applications/vgg16") -add_python_module("tensorflow/python/keras/applications/vgg19") -add_python_module("tensorflow/python/keras/applications/xception") -add_python_module("tensorflow/python/keras/backend") -add_python_module("tensorflow/python/keras/callbacks") -add_python_module("tensorflow/python/keras/constraints") -add_python_module("tensorflow/python/keras/datasets") -add_python_module("tensorflow/python/keras/datasets/boston_housing") -add_python_module("tensorflow/python/keras/datasets/cifar10") -add_python_module("tensorflow/python/keras/datasets/cifar100") -add_python_module("tensorflow/python/keras/datasets/fashion_mnist") -add_python_module("tensorflow/python/keras/datasets/imdb") -add_python_module("tensorflow/python/keras/datasets/mnist") -add_python_module("tensorflow/python/keras/datasets/reuters") -add_python_module("tensorflow/python/keras/estimator") -add_python_module("tensorflow/python/keras/initializers") -add_python_module("tensorflow/python/keras/layers") -add_python_module("tensorflow/python/keras/losses") -add_python_module("tensorflow/python/keras/metrics") -add_python_module("tensorflow/python/keras/models") -add_python_module("tensorflow/python/keras/optimizers") -add_python_module("tensorflow/python/keras/preprocessing") -add_python_module("tensorflow/python/keras/preprocessing/image") -add_python_module("tensorflow/python/keras/preprocessing/sequence") -add_python_module("tensorflow/python/keras/preprocessing/text") -add_python_module("tensorflow/python/keras/regularizers") -add_python_module("tensorflow/python/keras/utils") -add_python_module("tensorflow/python/keras/wrappers") -add_python_module("tensorflow/python/keras/wrappers/scikit_learn") -add_python_module("tensorflow/python/keras/_impl") -add_python_module("tensorflow/python/keras/_impl/keras") -add_python_module("tensorflow/python/keras/_impl/keras/applications") -add_python_module("tensorflow/python/keras/_impl/keras/datasets") -add_python_module("tensorflow/python/keras/_impl/keras/engine") -add_python_module("tensorflow/python/keras/_impl/keras/layers") -add_python_module("tensorflow/python/keras/_impl/keras/preprocessing") -add_python_module("tensorflow/python/keras/_impl/keras/utils") -add_python_module("tensorflow/python/keras/_impl/keras/wrappers") -add_python_module("tensorflow/python/kernel_tests") -add_python_module("tensorflow/python/kernel_tests/distributions") -add_python_module("tensorflow/python/kernel_tests/linalg") -add_python_module("tensorflow/python/layers") -add_python_module("tensorflow/python/lib") -add_python_module("tensorflow/python/lib/core") -add_python_module("tensorflow/python/lib/io") -add_python_module("tensorflow/python/ops") -add_python_module("tensorflow/python/ops/distributions") -add_python_module("tensorflow/python/ops/linalg") -add_python_module("tensorflow/python/ops/losses") -add_python_module("tensorflow/python/platform") -add_python_module("tensorflow/python/platform/default") -add_python_module("tensorflow/python/platform/summary") -add_python_module("tensorflow/python/profiler/") -add_python_module("tensorflow/python/profiler/internal") -add_python_module("tensorflow/python/saved_model") -add_python_module("tensorflow/python/summary") -add_python_module("tensorflow/python/summary/writer") -add_python_module("tensorflow/python/tools") -add_python_module("tensorflow/python/training") -add_python_module("tensorflow/python/user_ops") -add_python_module("tensorflow/python/util") -add_python_module("tensorflow/python/util/protobuf") -add_python_module("tensorflow/tools") -add_python_module("tensorflow/tools/graph_transforms") -add_python_module("tensorflow/contrib") -add_python_module("tensorflow/contrib/all_reduce") -add_python_module("tensorflow/contrib/all_reduce/python") -add_python_module("tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/java") -add_python_module("tensorflow/contrib/android/java/org") -add_python_module("tensorflow/contrib/android/java/org/tensorflow") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/jni") -add_python_module("tensorflow/contrib/bayesflow") -add_python_module("tensorflow/contrib/bayesflow/examples") -add_python_module("tensorflow/contrib/bayesflow/examples/reinforce_simple") -add_python_module("tensorflow/contrib/bayesflow/python") -add_python_module("tensorflow/contrib/bayesflow/python/kernel_tests") -add_python_module("tensorflow/contrib/bayesflow/python/ops") -add_python_module("tensorflow/contrib/boosted_trees") -add_python_module("tensorflow/contrib/boosted_trees/estimator_batch") -add_python_module("tensorflow/contrib/boosted_trees/ops") -add_python_module("tensorflow/contrib/boosted_trees/proto") -add_python_module("tensorflow/contrib/boosted_trees/python") -add_python_module("tensorflow/contrib/boosted_trees/python/kernel_tests") -add_python_module("tensorflow/contrib/boosted_trees/python/ops") -add_python_module("tensorflow/contrib/cloud") -add_python_module("tensorflow/contrib/cloud/kernels") -add_python_module("tensorflow/contrib/cloud/ops") -add_python_module("tensorflow/contrib/cloud/python") -add_python_module("tensorflow/contrib/cloud/python/ops") -add_python_module("tensorflow/contrib/cluster_resolver") -add_python_module("tensorflow/contrib/cluster_resolver/python") -add_python_module("tensorflow/contrib/cluster_resolver/python/training") -add_python_module("tensorflow/contrib/compiler") -add_python_module("tensorflow/contrib/copy_graph") -add_python_module("tensorflow/contrib/copy_graph/python") -add_python_module("tensorflow/contrib/copy_graph/python/util") -add_python_module("tensorflow/contrib/crf") -add_python_module("tensorflow/contrib/crf/python") -add_python_module("tensorflow/contrib/crf/python/kernel_tests") -add_python_module("tensorflow/contrib/crf/python/ops") -add_python_module("tensorflow/contrib/cudnn_rnn") -add_python_module("tensorflow/contrib/cudnn_rnn/kernels") -add_python_module("tensorflow/contrib/cudnn_rnn/ops") -add_python_module("tensorflow/contrib/cudnn_rnn/python") -add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/cudnn_rnn/python/layers") -add_python_module("tensorflow/contrib/cudnn_rnn/python/ops") -add_python_module("tensorflow/contrib/data") -add_python_module("tensorflow/contrib/data/python") -add_python_module("tensorflow/contrib/data/python/kernel_tests") -add_python_module("tensorflow/contrib/data/python/ops") -add_python_module("tensorflow/contrib/decision_trees") -add_python_module("tensorflow/contrib/decision_trees/proto") -add_python_module("tensorflow/contrib/deprecated") -add_python_module("tensorflow/contrib/distributions") -add_python_module("tensorflow/contrib/distributions/python") -add_python_module("tensorflow/contrib/distributions/python/kernel_tests") -add_python_module("tensorflow/contrib/distributions/python/ops") -add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") -add_python_module("tensorflow/contrib/eager") -add_python_module("tensorflow/contrib/eager/python") -add_python_module("tensorflow/contrib/estimator") -add_python_module("tensorflow/contrib/estimator/python") -add_python_module("tensorflow/contrib/estimator/python/estimator") -add_python_module("tensorflow/contrib/factorization") -add_python_module("tensorflow/contrib/factorization/examples") -add_python_module("tensorflow/contrib/factorization/kernels") -add_python_module("tensorflow/contrib/factorization/ops") -add_python_module("tensorflow/contrib/factorization/python") -add_python_module("tensorflow/contrib/factorization/python/kernel_tests") -add_python_module("tensorflow/contrib/factorization/python/ops") -add_python_module("tensorflow/contrib/ffmpeg") -add_python_module("tensorflow/contrib/ffmpeg/default") -add_python_module("tensorflow/contrib/ffmpeg/testdata") -add_python_module("tensorflow/contrib/framework") -add_python_module("tensorflow/contrib/framework/kernels") -add_python_module("tensorflow/contrib/framework/ops") -add_python_module("tensorflow/contrib/framework/python") -add_python_module("tensorflow/contrib/framework/python/framework") -add_python_module("tensorflow/contrib/framework/python/ops") -add_python_module("tensorflow/contrib/gan") -add_python_module("tensorflow/contrib/gan/python") -add_python_module("tensorflow/contrib/gan/python/eval") -add_python_module("tensorflow/contrib/gan/python/eval/python") -add_python_module("tensorflow/contrib/gan/python/features") -add_python_module("tensorflow/contrib/gan/python/features/python") -add_python_module("tensorflow/contrib/gan/python/estimator") -add_python_module("tensorflow/contrib/gan/python/estimator/python") -add_python_module("tensorflow/contrib/gan/python/losses") -add_python_module("tensorflow/contrib/gan/python/losses/python") -add_python_module("tensorflow/contrib/graph_editor") -add_python_module("tensorflow/contrib/graph_editor/examples") -add_python_module("tensorflow/contrib/graph_editor/tests") -add_python_module("tensorflow/contrib/grid_rnn") -add_python_module("tensorflow/contrib/grid_rnn/python") -add_python_module("tensorflow/contrib/grid_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/grid_rnn/python/ops") -add_python_module("tensorflow/contrib/hooks") -add_python_module("tensorflow/contrib/image") -add_python_module("tensorflow/contrib/image/ops") -add_python_module("tensorflow/contrib/image/python") -add_python_module("tensorflow/contrib/image/python/ops") -add_python_module("tensorflow/contrib/input_pipeline") -add_python_module("tensorflow/contrib/input_pipeline/ops") -add_python_module("tensorflow/contrib/input_pipeline/python") -add_python_module("tensorflow/contrib/input_pipeline/python/ops") -add_python_module("tensorflow/contrib/integrate") -add_python_module("tensorflow/contrib/integrate/python") -add_python_module("tensorflow/contrib/integrate/python/ops") -add_python_module("tensorflow/contrib/ios_examples") -add_python_module("tensorflow/contrib/ios_examples/benchmark") -add_python_module("tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/benchmark/data") -add_python_module("tensorflow/contrib/ios_examples/camera") -add_python_module("tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/camera/en.lproj") -add_python_module("tensorflow/contrib/ios_examples/simple") -add_python_module("tensorflow/contrib/ios_examples/simple/data") -add_python_module("tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj") -add_python_module("tensorflow/contrib/keras") -add_python_module("tensorflow/contrib/keras/api") -add_python_module("tensorflow/contrib/keras/api/keras") -add_python_module("tensorflow/contrib/keras/api/keras/activations") -add_python_module("tensorflow/contrib/keras/api/keras/applications") -add_python_module("tensorflow/contrib/keras/api/keras/applications/inception_v3") -add_python_module("tensorflow/contrib/keras/api/keras/applications/mobilenet") -add_python_module("tensorflow/contrib/keras/api/keras/applications/resnet50") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg16") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg19") -add_python_module("tensorflow/contrib/keras/api/keras/applications/xception") -add_python_module("tensorflow/contrib/keras/api/keras/backend") -add_python_module("tensorflow/contrib/keras/api/keras/callbacks") -add_python_module("tensorflow/contrib/keras/api/keras/constraints") -add_python_module("tensorflow/contrib/keras/api/keras/datasets") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/boston_housing") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar10") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar100") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/imdb") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/mnist") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/reuters") -add_python_module("tensorflow/contrib/keras/api/keras/initializers") -add_python_module("tensorflow/contrib/keras/api/keras/layers") -add_python_module("tensorflow/contrib/keras/api/keras/losses") -add_python_module("tensorflow/contrib/keras/api/keras/metrics") -add_python_module("tensorflow/contrib/keras/api/keras/models") -add_python_module("tensorflow/contrib/keras/api/keras/optimizers") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/image") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/sequence") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/text") -add_python_module("tensorflow/contrib/keras/api/keras/regularizers") -add_python_module("tensorflow/contrib/keras/api/keras/utils") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers/scikit_learn") -add_python_module("tensorflow/contrib/keras/python") -add_python_module("tensorflow/contrib/keras/python/keras") -add_python_module("tensorflow/contrib/keras/python/keras/applications") -add_python_module("tensorflow/contrib/keras/python/keras/datasets") -add_python_module("tensorflow/contrib/keras/python/keras/engine") -add_python_module("tensorflow/contrib/keras/python/keras/layers") -add_python_module("tensorflow/contrib/keras/python/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/python/keras/utils") -add_python_module("tensorflow/contrib/keras/python/keras/wrappers") -add_python_module("tensorflow/contrib/kernel_methods") -add_python_module("tensorflow/contrib/kernel_methods/python") -add_python_module("tensorflow/contrib/kernel_methods/python/mappers") -add_python_module("tensorflow/contrib/kfac") -add_python_module("tensorflow/contrib/kfac/examples") -add_python_module("tensorflow/contrib/kfac/python") -add_python_module("tensorflow/contrib/kfac/python/ops") -add_python_module("tensorflow/contrib/labeled_tensor") -add_python_module("tensorflow/contrib/labeled_tensor/python") -add_python_module("tensorflow/contrib/labeled_tensor/python/ops") -add_python_module("tensorflow/contrib/layers") -add_python_module("tensorflow/contrib/layers/kernels") -add_python_module("tensorflow/contrib/layers/ops") -add_python_module("tensorflow/contrib/layers/python") -add_python_module("tensorflow/contrib/layers/python/kernel_tests") -add_python_module("tensorflow/contrib/layers/python/layers") -add_python_module("tensorflow/contrib/layers/python/ops") -add_python_module("tensorflow/contrib/learn") -add_python_module("tensorflow/contrib/learn/python") -add_python_module("tensorflow/contrib/learn/python/learn") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/queues") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/transforms") -add_python_module("tensorflow/contrib/learn/python/learn/datasets") -add_python_module("tensorflow/contrib/learn/python/learn/datasets/data") -add_python_module("tensorflow/contrib/learn/python/learn/estimators") -add_python_module("tensorflow/contrib/learn/python/learn/learn_io") -add_python_module("tensorflow/contrib/learn/python/learn/ops") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/utils") -add_python_module("tensorflow/contrib/legacy_seq2seq") -add_python_module("tensorflow/contrib/legacy_seq2seq/python") -add_python_module("tensorflow/contrib/legacy_seq2seq/python/ops") -add_python_module("tensorflow/contrib/linalg") -add_python_module("tensorflow/contrib/linalg/python") -add_python_module("tensorflow/contrib/linalg/python/ops") -add_python_module("tensorflow/contrib/linalg/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer") -add_python_module("tensorflow/contrib/linear_optimizer/kernels") -add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc") -add_python_module("tensorflow/contrib/linear_optimizer/python") -add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer/python/ops") +FILE(READ python_modules.txt python_modules) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}") +STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}") + +foreach(python_module ${python_modules}) + add_python_module(${python_module}) +endforeach(python_module) + add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") @@ -513,157 +215,6 @@ add_custom_command( TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) -add_python_module("tensorflow/contrib/lookup") -add_python_module("tensorflow/contrib/losses") -add_python_module("tensorflow/contrib/losses/python") -add_python_module("tensorflow/contrib/losses/python/losses") -add_python_module("tensorflow/contrib/losses/python/metric_learning") -add_python_module("tensorflow/contrib/makefile") -add_python_module("tensorflow/contrib/makefile/test") -add_python_module("tensorflow/contrib/memory_stats") -add_python_module("tensorflow/contrib/memory_stats/kernels") -add_python_module("tensorflow/contrib/memory_stats/ops") -add_python_module("tensorflow/contrib/memory_stats/python") -add_python_module("tensorflow/contrib/memory_stats/python/kernel_tests") -add_python_module("tensorflow/contrib/memory_stats/python/ops") -add_python_module("tensorflow/contrib/meta_graph_transform") -add_python_module("tensorflow/contrib/metrics") -add_python_module("tensorflow/contrib/metrics/kernels") -add_python_module("tensorflow/contrib/metrics/ops") -add_python_module("tensorflow/contrib/metrics/python") -add_python_module("tensorflow/contrib/metrics/python/kernel_tests") -add_python_module("tensorflow/contrib/metrics/python/metrics") -add_python_module("tensorflow/contrib/metrics/python/ops") -add_python_module("tensorflow/contrib/model_pruning") -add_python_module("tensorflow/contrib/model_pruning/examples") -add_python_module("tensorflow/contrib/model_pruning/examples/cifar10") -add_python_module("tensorflow/contrib/model_pruning/python") -add_python_module("tensorflow/contrib/model_pruning/python/layers") -add_python_module("tensorflow/contrib/ndlstm") -add_python_module("tensorflow/contrib/ndlstm/python") -add_python_module("tensorflow/contrib/nn") -add_python_module("tensorflow/contrib/nn/python") -add_python_module("tensorflow/contrib/nn/python/ops") -add_python_module("tensorflow/contrib/nccl") -add_python_module("tensorflow/contrib/nccl/kernels") -add_python_module("tensorflow/contrib/nccl/ops") -add_python_module("tensorflow/contrib/nccl/python") -add_python_module("tensorflow/contrib/nccl/python/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/kernels") -add_python_module("tensorflow/contrib/nearest_neighbor/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/python") -add_python_module("tensorflow/contrib/nearest_neighbor/python/kernel_tests") -add_python_module("tensorflow/contrib/nearest_neighbor/python/ops") -add_python_module("tensorflow/contrib/opt") -add_python_module("tensorflow/contrib/opt/python") -add_python_module("tensorflow/contrib/opt/python/training") -add_python_module("tensorflow/contrib/pi_examples") -add_python_module("tensorflow/contrib/pi_examples/camera") -add_python_module("tensorflow/contrib/pi_examples/label_image") -add_python_module("tensorflow/contrib/pi_examples/label_image/data") -add_python_module("tensorflow/contrib/predictor") -add_python_module("tensorflow/contrib/quantization") -add_python_module("tensorflow/contrib/quantization/python") -add_python_module("tensorflow/contrib/quantize") -add_python_module("tensorflow/contrib/quantize/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") -add_python_module("tensorflow/contrib/resampler") -add_python_module("tensorflow/contrib/resampler/kernels") -add_python_module("tensorflow/contrib/resampler/ops") -add_python_module("tensorflow/contrib/resampler/python") -add_python_module("tensorflow/contrib/resampler/python/ops") -add_python_module("tensorflow/contrib/rnn") -add_python_module("tensorflow/contrib/rnn/kernels") -add_python_module("tensorflow/contrib/rnn/ops") -add_python_module("tensorflow/contrib/rnn/python") -add_python_module("tensorflow/contrib/rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/rnn/python/ops") -add_python_module("tensorflow/contrib/saved_model") -add_python_module("tensorflow/contrib/saved_model/python") -add_python_module("tensorflow/contrib/saved_model/python/saved_model") -add_python_module("tensorflow/contrib/seq2seq") -add_python_module("tensorflow/contrib/seq2seq/kernels") -add_python_module("tensorflow/contrib/seq2seq/ops") -add_python_module("tensorflow/contrib/seq2seq/python") -add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests") -add_python_module("tensorflow/contrib/seq2seq/python/ops") -add_python_module("tensorflow/contrib/session_bundle") -add_python_module("tensorflow/contrib/session_bundle/example") -add_python_module("tensorflow/contrib/session_bundle/testdata") -add_python_module("tensorflow/contrib/signal") -add_python_module("tensorflow/contrib/signal/python") -add_python_module("tensorflow/contrib/signal/python/ops") -add_python_module("tensorflow/contrib/slim") -add_python_module("tensorflow/contrib/slim/python") -add_python_module("tensorflow/contrib/slim/python/slim") -add_python_module("tensorflow/contrib/slim/python/slim/data") -add_python_module("tensorflow/contrib/slim/python/slim/nets") -add_python_module("tensorflow/contrib/solvers") -add_python_module("tensorflow/contrib/solvers/python") -add_python_module("tensorflow/contrib/solvers/python/ops") -add_python_module("tensorflow/contrib/sparsemax") -add_python_module("tensorflow/contrib/sparsemax/python") -add_python_module("tensorflow/contrib/sparsemax/python/ops") -add_python_module("tensorflow/contrib/specs") -add_python_module("tensorflow/contrib/specs/python") -add_python_module("tensorflow/contrib/staging") -add_python_module("tensorflow/contrib/stat_summarizer") -add_python_module("tensorflow/contrib/stateless") -add_python_module("tensorflow/contrib/tensorboard") -add_python_module("tensorflow/contrib/tensorboard/plugins") -add_python_module("tensorflow/contrib/tensorboard/plugins/projector") -add_python_module("tensorflow/contrib/tensor_forest") -add_python_module("tensorflow/contrib/tensor_forest/client") -add_python_module("tensorflow/contrib/tensor_forest/core") -add_python_module("tensorflow/contrib/tensor_forest/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/data") -add_python_module("tensorflow/contrib/tensor_forest/hybrid") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/layers") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/models") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/ops") -add_python_module("tensorflow/contrib/tensor_forest/python") -add_python_module("tensorflow/contrib/tensor_forest/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/python/ops") -add_python_module("tensorflow/contrib/testing") -add_python_module("tensorflow/contrib/testing/python") -add_python_module("tensorflow/contrib/testing/python/framework") -add_python_module("tensorflow/contrib/text") -add_python_module("tensorflow/contrib/text/kernels") -add_python_module("tensorflow/contrib/text/ops") -add_python_module("tensorflow/contrib/text/python") -add_python_module("tensorflow/contrib/text/python/ops") -add_python_module("tensorflow/contrib/tfprof") -add_python_module("tensorflow/contrib/timeseries") -add_python_module("tensorflow/contrib/timeseries/examples") -add_python_module("tensorflow/contrib/timeseries/examples/data") -add_python_module("tensorflow/contrib/timeseries/python") -add_python_module("tensorflow/contrib/timeseries/python/timeseries") -add_python_module("tensorflow/contrib/timeseries/python/timeseries/state_space_models") -add_python_module("tensorflow/contrib/tpu") -add_python_module("tensorflow/contrib/tpu/ops") -add_python_module("tensorflow/contrib/tpu/profiler") -add_python_module("tensorflow/contrib/tpu/python") -add_python_module("tensorflow/contrib/tpu/python/ops") -add_python_module("tensorflow/contrib/tpu/python/profiler") -add_python_module("tensorflow/contrib/tpu/python/tpu") -add_python_module("tensorflow/contrib/training") -add_python_module("tensorflow/contrib/training/python") -add_python_module("tensorflow/contrib/training/python/training") -add_python_module("tensorflow/contrib/util") -add_python_module("tensorflow/contrib/reduce_slice_ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/kernels") -add_python_module("tensorflow/contrib/reduce_slice_ops/ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/python") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") -add_python_module("tensorflow/contrib/summary") # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") @@ -738,7 +289,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) # containing the wrappers. add_custom_command( OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION} - COMMAND ${tf_python_op_lib_name}_gen_python @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} + COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} DEPENDS ${tf_python_op_lib_name}_gen_python ) @@ -816,6 +367,9 @@ GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) + GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -888,6 +442,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor.h" @@ -898,6 +454,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc" "${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h" @@ -1014,6 +572,20 @@ target_link_libraries(pywrap_tensorflow_internal PRIVATE ) if(WIN32) + + # include contrib/periodic_resample as .so + # + set(tf_periodic_resample_srcs + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc" + ) + + AddUserOps(TARGET _periodic_resample_op + SOURCES "${tf_periodic_resample_srcs}" + DEPENDS pywrap_tensorflow_internal tf_python_ops + DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/) + # include contrib/nearest_neighbor as .so # set(tf_nearest_neighbor_srcs diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index d4099f32797e404cc2f3c16b95e18d6b91d13981..571d2b0decb5e9afcec2314f9837546f0974e90d 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -45,7 +45,7 @@ if(WIN32) $ $ ) - + set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def") set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE) diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 8d95f0d3e813885c581b37cfc0b89e24d04ae6b1..91ca33f4c4d5f6c822f45b0676e6e46d2e4c2860 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -61,18 +61,18 @@ file(GLOB tf_stream_executor_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/platform/default/*.h" ) -if (tensorflow_ENABLE_GPU) +if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" ) list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) -endif() +endif() #file(GLOB_RECURSE tf_stream_executor_test_srcs # "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.cc" # "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.h" #) -#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) +#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) if (NOT WIN32) set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lgomp") diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 5d6ba9ca8d85e9a2d19b7f3e488822a8f21c6821..94ca4b00175dffb4461fca34c5ecd79ba79be778 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -139,12 +139,15 @@ if (tensorflow_BUILD_PYTHON_TESTS) file(GLOB_RECURSE tf_test_src_py ${tf_test_rnn_src_py} + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/debug/cli/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/internal/*_test.py" @@ -153,7 +156,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/stateless/python/kernel_tests/*_test.py" @@ -171,7 +175,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/distributions/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/learn/*_test.py" ) @@ -217,6 +220,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) # TFDBG grpc:// mode is not yet available on Windows. "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py" # stl on windows handles overflows different "${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py" @@ -225,6 +229,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Numerical issues, calculations off. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py" # Float division by zero "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" # Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces. @@ -233,11 +241,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py" # IteratorGetMax OutOfRangeError @@ -261,9 +269,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_grad_test.py" # cudaSolver handle creation fails. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Dataset tests - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 @@ -294,6 +302,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Test should only be run manually "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py" + # Depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/control_flow_util_test.py" ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/copy_graph/__init__.py b/tensorflow/contrib/copy_graph/__init__.py index 30a0aac140b576c501595fd6c8767b7dddde8e58..61ee39e4be1f0471309bb2672476dd9100cbfd49 100644 --- a/tensorflow/contrib/copy_graph/__init__.py +++ b/tensorflow/contrib/copy_graph/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Functions to copy elements between graphs. - -See the @{$python/contrib.copy_graph} guide. """ from __future__ import absolute_import diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 8c2528f548799f9facef740b0134ac56966b2b04..bae66ffd4289308f2cbfc730ec50d057b13923fb 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -19,7 +19,7 @@ from one graph to another. The copied elements are initialized inside a user-specified scope in the other graph. There are separate functions to copy ops and variables. There is also a function to retrive the copied version of an op from the -first graph inside a scope in the second graph. +first graph inside a scope in the second graph. @@copy_op_to_graph @@copy_variable_to_graph @@ -225,7 +225,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) + to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 964ec754413f44d90c8e7e5e9358f82102f2cbcc..b47fb426a193e0fcc075deafae3eaab698f18ec9 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -32,27 +32,41 @@ from tensorflow.python.platform import test class CrfTest(test.TestCase): def testCrfSequenceScore(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - sequence_score = crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - sequence_score = array_ops.squeeze(sequence_score, [0]) - tf_sequence_score = sess.run(sequence_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - expected_binary_score = sum( - transition_params[tag_indices[i], tag_indices[i + 1]] - for i in range(sequence_lengths - 1)) - expected_sequence_score = expected_unary_score + expected_binary_score - self.assertAllClose(tf_sequence_score, expected_sequence_score) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([1], dtype=np.int32) + ] + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + with self.test_session() as sess: + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + tf_sequence_score = sess.run(sequence_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + expected_sequence_score = expected_unary_score + expected_binary_score + self.assertAllClose(tf_sequence_score, expected_sequence_score) def testCrfUnaryScore(self): inputs = np.array( @@ -89,38 +103,54 @@ class CrfTest(test.TestCase): self.assertAllClose(tf_binary_score, expected_binary_score) def testCrfLogNorm(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - all_sequence_scores = [] - - # Compare the dynamic program with brute force computation. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequence_scores.append( - crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params))) - - brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores) - log_norm = crf.crf_log_norm( - inputs=array_ops.expand_dims(inputs, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - log_norm = array_ops.squeeze(log_norm, [0]) - tf_brute_force_log_norm, tf_log_norm = sess.run( - [brute_force_log_norm, log_norm]) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[3, -1, 3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + with self.test_session() as sess: + all_sequence_scores = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequence_scores.append( + crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params))) + + brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores) + log_norm = crf.crf_log_norm( + inputs=array_ops.expand_dims(inputs, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + log_norm = array_ops.squeeze(log_norm, [0]) + tf_brute_force_log_norm, tf_log_norm = sess.run( + [brute_force_log_norm, log_norm]) - self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) def testCrfLogLikelihood(self): inputs = np.array( @@ -201,50 +231,66 @@ class CrfTest(test.TestCase): expected_max_sequence[:sequence_lengths]) def testCrfDecode(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[-1, 2, 1]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] - with self.test_session() as sess: - all_sequence_scores = [] - all_sequences = [] - - # Compare the dynamic program with brute force computation. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequences.append(tag_indices) - sequence_score = crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - sequence_score = array_ops.squeeze(sequence_score, [0]) - all_sequence_scores.append(sequence_score) - - tf_all_sequence_scores = sess.run(all_sequence_scores) - - expected_max_sequence_index = np.argmax(tf_all_sequence_scores) - expected_max_sequence = all_sequences[expected_max_sequence_index] - expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] - - actual_max_sequence, actual_max_score = crf.crf_decode( - array_ops.expand_dims(inputs, 0), - constant_op.constant(transition_params), - array_ops.expand_dims(sequence_lengths, 0)) - actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) - actual_max_score = array_ops.squeeze(actual_max_score, [0]) - tf_actual_max_sequence, tf_actual_max_score = sess.run( - [actual_max_sequence, actual_max_score]) - - self.assertAllClose(tf_actual_max_score, expected_max_score) - self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), - expected_max_sequence[:sequence_lengths]) + with self.test_session() as sess: + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = sess.run(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] + + actual_max_sequence, actual_max_score = crf.crf_decode( + array_ops.expand_dims(inputs, 0), + constant_op.constant(transition_params), + array_ops.expand_dims(sequence_lengths, 0)) + actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) + actual_max_score = array_ops.squeeze(actual_max_score, [0]) + tf_actual_max_sequence, tf_actual_max_score = sess.run( + [actual_max_sequence, actual_max_score]) + + self.assertAllClose(tf_actual_max_score, expected_max_score) + self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), + expected_max_sequence[:sequence_lengths]) if __name__ == "__main__": diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 8b621732c1391feda011d21b175bc0b042b9eec7..7f5ae937b26f465076c6976429697c35924432e5 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -53,7 +53,9 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn @@ -101,12 +103,29 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, Returns: sequence_scores: A [batch_size] vector of unnormalized sequence scores. """ - # Compute the scores of the given tag sequence. - unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) - binary_scores = crf_binary_score(tag_indices, sequence_lengths, - transition_params) - sequence_scores = unary_scores + binary_scores - return sequence_scores + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of the single tag. + def _single_seq_fn(): + batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0] + example_inds = array_ops.reshape( + math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) + return array_ops.gather_nd( + array_ops.squeeze(inputs, [1]), + array_ops.concat([example_inds, tag_indices], axis=1)) + + def _multi_seq_fn(): + # Compute the scores of the given tag sequence. + unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) + binary_scores = crf_binary_score(tag_indices, sequence_lengths, + transition_params) + sequence_scores = unary_scores + binary_scores + return sequence_scores + + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], + 1), + fn1=_single_seq_fn, + fn2=_multi_seq_fn) def crf_log_norm(inputs, sequence_lengths, transition_params): @@ -124,19 +143,32 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # algorithm. first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1]) first_input = array_ops.squeeze(first_input, [1]) - rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) - # Compute the alpha values in the forward algorithm in order to get the - # partition function. - forward_cell = CrfForwardRnnCell(transition_params) - _, alphas = rnn.dynamic_rnn( - cell=forward_cell, - inputs=rest_of_input, - sequence_length=sequence_lengths - 1, - initial_state=first_input, - dtype=dtypes.float32) - log_norm = math_ops.reduce_logsumexp(alphas, [1]) - return log_norm + # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over + # the "initial state" (the unary potentials). + def _single_seq_fn(): + return math_ops.reduce_logsumexp(first_input, [1]) + + def _multi_seq_fn(): + """Forward computation of alpha values.""" + rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) + + # Compute the alpha values in the forward algorithm in order to get the + # partition function. + forward_cell = CrfForwardRnnCell(transition_params) + _, alphas = rnn.dynamic_rnn( + cell=forward_cell, + inputs=rest_of_input, + sequence_length=sequence_lengths - 1, + initial_state=first_input, + dtype=dtypes.float32) + log_norm = math_ops.reduce_logsumexp(alphas, [1]) + return log_norm + + max_seq_len = array_ops.shape(inputs)[1] + return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_likelihood(inputs, @@ -437,45 +469,64 @@ def crf_decode(potentials, transition_params, sequence_length): sequence_length: A [batch_size] vector of true sequence lengths. Returns: - decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32. + decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. Contains the highest scoring tag indices. - best_score: A [batch_size] tensor, containing the score of decode_tags. + best_score: A [batch_size] vector, containing the score of `decode_tags`. """ - # For simplicity, in shape comments, denote: - # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). - num_tags = potentials.get_shape()[2].value - - # Computes forward decoding. Get last score and backpointers. - crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) - initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) - initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] - inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] - backpointers, last_score = rnn.dynamic_rnn( - crf_fwd_cell, - inputs=inputs, - sequence_length=sequence_length - 1, - initial_state=initial_state, - time_major=False, - dtype=dtypes.int32) # [B, T - 1, O], [B, O] - backpointers = gen_array_ops.reverse_sequence( - backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O] - - # Computes backward decoding. Extract tag indices from backpointers. - crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) - initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), - dtype=dtypes.int32) # [B] - initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] - decode_tags, _ = rnn.dynamic_rnn( - crf_bwd_cell, - inputs=backpointers, - sequence_length=sequence_length - 1, - initial_state=initial_state, - time_major=False, - dtype=dtypes.int32) # [B, T - 1, 1] - decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] - decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T] - decode_tags = gen_array_ops.reverse_sequence( - decode_tags, sequence_length, seq_dim=1) # [B, T] - - best_score = math_ops.reduce_max(last_score, axis=1) # [B] - return decode_tags, best_score + # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag + # and the max activation. + def _single_seq_fn(): + squeezed_potentials = array_ops.squeeze(potentials, [1]) + decode_tags = array_ops.expand_dims( + math_ops.argmax(squeezed_potentials, axis=1), 1) + best_score = math_ops.reduce_max(squeezed_potentials, axis=1) + return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score + + def _multi_seq_fn(): + """Decoding of highest scoring sequence.""" + + # For simplicity, in shape comments, denote: + # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). + num_tags = potentials.get_shape()[2].value + + # Computes forward decoding. Get last score and backpointers. + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] + inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] + backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] + crf_fwd_cell, + inputs=inputs, + sequence_length=sequence_length - 1, + initial_state=initial_state, + time_major=False, + dtype=dtypes.int32) + backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O] + backpointers, sequence_length - 1, seq_dim=1) + + # Computes backward decoding. Extract tag indices from backpointers. + crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) + initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B] + dtype=dtypes.int32) + initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] + decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1] + crf_bwd_cell, + inputs=backpointers, + sequence_length=sequence_length - 1, + initial_state=initial_state, + time_major=False, + dtype=dtypes.int32) + decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] + decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T] + axis=1) + decode_tags = gen_array_ops.reverse_sequence( # [B, T] + decode_tags, sequence_length, seq_dim=1) + + best_score = math_ops.reduce_max(last_score, axis=1) # [B] + return decode_tags, best_score + + return utils.smart_cond( + pred=math_ops.equal( + potentials.shape[1].value or array_ops.shape(potentials)[1], 1), + fn1=_single_seq_fn, + fn2=_multi_seq_fn) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index fce2c03e69bc4b8b0ac46b8e081a33c43c9d41ab..0751624bc4b7fbf413c342db3e5c440c9d572cd4 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -146,10 +146,10 @@ cuda_py_test( cuda_py_test( name = "cudnn_rnn_ops_benchmark", - size = "large", + size = "small", srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], additional_deps = [ - ":cudnn_rnn_ops_py", + ":cudnn_rnn_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -164,7 +164,6 @@ cuda_py_test( "//tensorflow/python:variables", ], tags = [ - "manual", "noasan", # http://b/62067814 "nomsan", "notsan", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py index ff409ac71826f1f0f57e9133d768003f849abc09..4fc5ff1bd1887c4532e95fcf0e791d72b20471b0 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py @@ -20,8 +20,8 @@ from __future__ import print_function import time +from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops -from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -29,8 +29,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -44,19 +43,19 @@ class CudnnRNNBenchmark(test.Benchmark): "large": { "num_layers": 4, "num_units": 1024, - "seq_length": 40, + "seq_length": 50, "batch_size": 64, }, "medium": { "num_layers": 4, "num_units": 512, - "seq_length": 30, + "seq_length": 50, "batch_size": 64, }, "small": { "num_layers": 4, "num_units": 128, - "seq_length": 20, + "seq_length": 50, "batch_size": 64, }, } @@ -71,7 +70,7 @@ class CudnnRNNBenchmark(test.Benchmark): def _BenchmarkOp(self, op, desc): burn_in_steps = 10 - benchmark_steps = 40 + benchmark_steps = 20 with session.Session() as sess: sess.run(variables.global_variables_initializer()) for i in xrange(burn_in_steps + benchmark_steps): @@ -126,16 +125,12 @@ class CudnnRNNBenchmark(test.Benchmark): seq_length = config["seq_length"] with ops.Graph().as_default(), ops.device("/device:GPU:0"): - inputs = seq_length * [ - array_ops.zeros([batch_size, num_units], dtypes.float32) - ] - initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) - - cell = rnn_cell.LSTMCell( - num_units=num_units, initializer=initializer, state_is_tuple=True) - multi_cell = rnn_cell.MultiRNNCell( - [cell() for _ in range(num_layers)]) - outputs, final_state = core_rnn.static_rnn( + inputs = array_ops.zeros([batch_size, seq_length, num_units], + dtypes.float32) + + multi_cell = contrib_rnn.MultiRNNCell( + [contrib_rnn.BasicLSTMCell(num_units) for _ in range(num_layers)]) + outputs, final_state = rnn.dynamic_rnn( multi_cell, inputs, dtype=dtypes.float32) trainable_variables = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) @@ -154,14 +149,12 @@ class CudnnRNNBenchmark(test.Benchmark): seq_length = config["seq_length"] with ops.Graph().as_default(), ops.device("/device:GPU:0"): - inputs = seq_length * [ - array_ops.zeros([batch_size, num_units], dtypes.float32) - ] - cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units) # pylint: disable=cell-var-from-loop - - multi_cell = rnn_cell.MultiRNNCell( - [cell() for _ in range(num_layers)]) - outputs, final_state = core_rnn.static_rnn( + inputs = array_ops.zeros([batch_size, seq_length, num_units], + dtypes.float32) + + multi_cell = contrib_rnn.MultiRNNCell( + [lstm_ops.LSTMBlockCell(num_units) for _ in range(num_layers)]) + outputs, final_state = rnn.dynamic_rnn( multi_cell, inputs, dtype=dtypes.float32) trainable_variables = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index f7d8a084d9c12c05c411ae0751854d1823a818ec..3b1c33063f1214b68f79560f50d56bf5d31c9560 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -18,6 +18,7 @@ py_library( "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 6e43ae0e6320fa237435b837780ec8aea941872b..c9ad091bd44d6e3a9368e182c3df9fc1c6e48071 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -17,6 +17,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Dataset +@@Counter @@Iterator @@TFRecordDataset @@FixedLengthRecordDataset @@ -33,6 +34,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@unbatch @@parallel_interleave @@rejection_resample +@@scan @@sloppy_interleave @@get_single_element @@ -48,6 +50,7 @@ from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import unbatch +from tensorflow.contrib.data.python.ops.counter import Counter from tensorflow.contrib.data.python.ops.dataset_ops import Dataset from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset @@ -62,6 +65,8 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.readers import TextLineDataset from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample +from tensorflow.contrib.data.python.ops.scan_ops import scan +from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 3d4e46408e20a7c7c39c2601458b237b18676b72..d5ad14532780ff6b0cc40ae5a206c50ca70750ba 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", @@ -110,6 +110,7 @@ py_test( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -131,6 +132,8 @@ py_library( "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) @@ -140,7 +143,9 @@ py_test( size = "small", srcs = ["filter_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -152,21 +157,28 @@ py_test( ], ) -py_test( +tf_py_test( name = "flat_map_dataset_op_test", size = "small", srcs = ["flat_map_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ + ":dataset_serialization_test", + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:session", "//tensorflow/python:training", - "//third_party/py/numpy", + "//tensorflow/python:variable_scope", ], + grpc_enabled = True, + tags = ["no_pip"], ) py_test( @@ -178,6 +190,7 @@ py_test( "manual", # b/67958761 ], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", @@ -194,13 +207,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "iterator_ops_cluster_test", size = "small", srcs = ["iterator_ops_cluster_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], - deps = [ + additional_deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -214,14 +225,19 @@ py_test( "//tensorflow/python:session", "//tensorflow/python/data/ops:iterator_ops", ], + grpc_enabled = True, + tags = [ + "no_windows", + "oss_serial", + ], ) -py_test( +tf_py_test( name = "iterator_ops_test", size = "small", srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", @@ -243,8 +259,8 @@ py_test( "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", ], + grpc_enabled = True, ) py_test( @@ -264,12 +280,13 @@ py_test( py_test( name = "map_dataset_op_test", - size = "small", + size = "medium", srcs = ["map_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -278,23 +295,35 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:function", "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", - "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) +py_test( + name = "prefetch_dataset_op_test", + size = "small", + srcs = ["prefetch_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "range_dataset_op_test", size = "small", @@ -323,9 +352,10 @@ py_test( py_test( name = "reader_dataset_ops_test", - size = "small", + size = "medium", srcs = ["reader_dataset_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", @@ -362,8 +392,25 @@ py_test( ) py_test( - name = "sequence_dataset_op_test", + name = "scan_dataset_op_test", size = "small", + srcs = ["scan_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sequence_dataset_op_test", + size = "medium", srcs = ["sequence_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -392,12 +439,15 @@ py_test( py_test( name = "shuffle_dataset_op_test", - size = "small", + size = "medium", srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -426,6 +476,21 @@ py_test( ], ) +py_test( + name = "stats_dataset_ops_test", + size = "small", + srcs = ["stats_dataset_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + ], +) + py_test( name = "zip_dataset_op_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 09416f8302842355da438aa35747bdc178ed5f4f..506eefbef0204284a103827180c13b13200a3f93 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -104,14 +104,58 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testBatchSparseError(self): + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) - def _map_fn(i): - return sparse_tensor.SparseTensor( - indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + def testBatchSparse(self): - with self.assertRaises(TypeError): - _ = dataset_ops.Dataset.range(10).map(_map_fn).batch(10) + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).batch( + 5).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testNestedBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch( + 2).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]], + values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + dense_shape=[2, 5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) def testPaddedBatchDataset(self): seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) @@ -250,7 +294,7 @@ class BatchDatasetTest(test.TestCase): def testPaddedBatchSparseError(self): def _map_fn(i): - return sparse_tensor.SparseTensor( + return sparse_tensor.SparseTensorValue( indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i with self.assertRaises(TypeError): @@ -438,6 +482,30 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testBatchAndDropRemainderSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(12).map(_sparse).apply( + batching.batch_and_drop_remainder(5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testPaddedBatchAndDropRemainder(self): els = [] for length in [3, 6, 9, 4, 12, 10, 2]: @@ -474,6 +542,16 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPaddedBatchAndDropRemainderSparseError(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).apply( + batching.padded_batch_and_drop_remainder(5)) + def testBatchAndDropRemainderShapeInference(self): components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder(dtypes.int32, shape=[None]), @@ -499,17 +577,7 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def testBatchAndDropRemainderSparseError(self): - - def _map_fn(i): - return sparse_tensor.SparseTensor( - indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i - - with self.assertRaises(TypeError): - _ = dataset_ops.Dataset.range(10).map(_map_fn).apply( - batching.batch_and_drop_remainder(10)) - - def testBatchAndMapDataset(self): + def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). @@ -525,7 +593,10 @@ class BatchDatasetTest(test.TestCase): iterator = ( dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( - batching.map_and_batch(_map_fn, batch_size)) + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -559,7 +630,11 @@ class BatchDatasetTest(test.TestCase): for j in range(8): self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) - # The last batch should fail with `OutOfRange`. + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range((14 * 7) % 8): + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -572,6 +647,36 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def testBatchAndMapDataset(self): + return self._testBatchAndMapDatasetHelper() + + def testBatchAndMapDatasetWithParallelBatching(self): + return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + + def testMapAndBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).apply( + batching.map_and_batch(_sparse, 5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testBatchAndMapDatasetFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( @@ -631,5 +736,41 @@ class BatchDatasetSerializationTest( num_outputs) +class PaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=[-1]) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index 0f1c8838ca111c7674fa4f7b16a8a5f6590281f4..55a1d3b95b212466b262ad3c26f1efd7ed0e067e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -39,7 +40,7 @@ from tensorflow.python.platform import test class DatasetConstructorTest(test.TestCase): - def testTensorDataset(self): + def testFromTensors(self): """Test an dataset that represents a single tuple of tensors.""" components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) @@ -59,7 +60,75 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testTensorSliceDataset(self): + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testFromTensorsSparse(self): + """Test an dataset that represents a single tuple of tensors.""" + components = (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1]]), + values=np.array([-1, 1]), + dense_shape=np.array([2, 2]))) + + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual( + [tensor_shape.TensorShape(c.dense_shape) for c in components], + [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + results = sess.run(get_next) + for component, result_component in zip(components, results): + self.assertSparseValuesEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorsMixed(self): + """Test an dataset that represents a single tuple of tensors.""" + components = (np.array(1), np.array([1, 2, 3]), np.array(37.0), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1]]), + values=np.array([-1, 1]), + dense_shape=np.array([2, 2]))) + + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([ + tensor_shape.TensorShape(c.dense_shape) + if sparse_tensor.is_sparse(c) else c.shape for c in components + ], [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + results = sess.run(get_next) + for component, result_component in zip(components, results): + if sparse_tensor.is_sparse(component): + self.assertSparseValuesEqual(component, result_component) + else: + self.assertAllEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorSlices(self): """Test an dataset that represents the slices from a tuple of tensors.""" components = ( np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( @@ -84,7 +153,127 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testTensorSliceDatasetWithDict(self): + def testFromTensorSlicesSparse(self): + """Test an dataset that represents the slices from a tuple of tensors.""" + components = (sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1], [2, 2]]), + values=np.array([1, 2, 3]), + dense_shape=np.array([3, 3]))) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual( + [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components], + [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + expected = [ + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([1]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[1]]), + values=np.array([2]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[2]]), + values=np.array([3]), + dense_shape=np.array([3]))), + ] + for i in range(3): + results = sess.run(get_next) + for component, result_component in zip(expected[i], results): + self.assertSparseValuesEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorSlicesMixed(self): + """Test an dataset that represents the slices from a tuple of tensors.""" + components = (np.tile(np.array([[1], [2], [3]]), 20), + np.tile(np.array([[12], [13], [14]]), 22), + np.array([37.0, 38.0, 39.0]), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1], [2, 2]]), + values=np.array([1, 2, 3]), + dense_shape=np.array([3, 3]))) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([ + tensor_shape.TensorShape(c.dense_shape[1:]) + if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components + ], [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + expected = [ + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([1]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[1]]), + values=np.array([2]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[2]]), + values=np.array([3]), + dense_shape=np.array([3]))), + ] + for i in range(3): + results = sess.run(get_next) + for component, result_component in zip( + (zip(*components[:3])[i] + expected[i]), results): + if sparse_tensor.is_sparse(component): + self.assertSparseValuesEqual(component, result_component) + else: + self.assertAllEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorSlicesWithDict(self): components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} iterator = (dataset_ops.Dataset.from_tensor_slices(components) .make_initializable_iterator()) @@ -105,7 +294,7 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testSparseTensorSliceDataset(self): + def testFromSparseTensorSlices(self): """Test a dataset based on slices of a `tf.SparseTensor`.""" st = array_ops.sparse_placeholder(dtypes.float64) iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st) diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 0a9e99fd99eaff03ae242ca6cf9cc5e231da3038..bf25cc60a1c0efc09bed6501fd2d6f4ccb07764b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -23,9 +23,11 @@ import os import numpy as np from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib @@ -63,6 +65,8 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn1, num_outputs, sparse_tensors=sparse_tensors) self.verify_reset_restored_iterator( ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_restore_in_empty_graph( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) if ds_fn2: self.verify_restore_in_modified_graph( ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors) @@ -229,6 +233,7 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, sparse_tensors=sparse_tensors) with self.test_session(graph=g) as sess: self._restore(saver, sess) + sess.run(variables.global_variables_initializer()) sess.run(init_op) for _ in range(num_outputs): actual.append(sess.run(get_next_op)) @@ -299,6 +304,97 @@ class DatasetSerializationTestBase(test.TestCase): self.match(expected, actual) + def verify_restore_in_empty_graph(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to restore an iterator in an empty graph. + + Builds an input pipeline using ds_fn, runs it for `break_point` steps + and saves a checkpoint. Then builds a new empty graph, restores + the checkpoint from ds_fn and verifies that the restore is successful. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn + # in `expected`. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + expected = self.gen_outputs( + ds_fn, [], + num_outputs - break_point, + ckpt_saved=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Generate `break_point` items from ds_fn and save checkpoint. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Build an empty graph but load checkpoint for ds_fn. + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph( + ds_fn, sparse_tensors=sparse_tensors) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.match(expected, actual) + + def verify_error_on_save(self, + ds_fn, + num_outputs, + error, + break_point=None, + sparse_tensors=False): + """Attempts to save a non-saveable iterator. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + error: Declared error when trying to save iterator. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + + break_point = num_outputs // 2 if not break_point else break_point + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(error): + self._save(sess, saver) + def verify_run_with_breaks(self, ds_fn, break_points, @@ -395,9 +491,11 @@ class DatasetSerializationTestBase(test.TestCase): with self.test_session(graph=g) as sess: if ckpt_saved: if init_before_restore: + sess.run(variables.global_variables_initializer()) sess.run(init_op) self._restore(saver, sess) else: + sess.run(variables.global_variables_initializer()) sess.run(init_op) start = break_points[i - 1] if i > 0 else 0 end = break_points[i] if i < len(break_points) else num_outputs @@ -466,6 +564,18 @@ class DatasetSerializationTestBase(test.TestCase): saver = saver_lib.Saver(allow_empty=True) return init_op, get_next, saver + def _build_empty_graph(self, ds_fn, sparse_tensors=False): + iterator = iterator_ops.Iterator.from_structure( + self._get_output_types(ds_fn), self._get_output_shapes(ds_fn)) + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return get_next, saver + def _add_iterator_ops_to_collection(self, init_op, get_next, @@ -495,6 +605,10 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default(): return ds_fn().output_types + def _get_output_shapes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_shapes + def _ckpt_path(self): return os.path.join(self.get_temp_dir(), "iterator") diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 67c49d77e2489a942fbf79286ec6ebc0af29a45e..5921be2ae89ba1bbbb8d6e3a509cf49c65949544 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -131,9 +132,12 @@ class FilterDatasetTest(test.TestCase): self.assertAllEqual(a.dense_shape, b.dense_shape) def testSparse(self): + def _map_fn(i): - return sparse_tensor.SparseTensor( - indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])), i def _filter_fn(_, i): return math_ops.equal(i % 2, 0) @@ -148,13 +152,48 @@ class FilterDatasetTest(test.TestCase): sess.run(init_op) for i in range(5): actual = sess.run(get_next) - expected = sparse_tensor.SparseTensor( - indices=[[0, 0]], values=[i*2], dense_shape=[1, 1]) self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, expected.eval()) + self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) +class FilterDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_filter_range_graph(self, div): + return dataset_ops.Dataset.range(100).filter( + lambda x: math_ops.not_equal(math_ops.mod(x, div), 2)) + + def testFilterCore(self): + div = 3 + num_outputs = np.sum([x % 3 is not 2 for x in range(100)]) + self.run_core_tests(lambda: self._build_filter_range_graph(div), + lambda: self._build_filter_range_graph(div * 2), + num_outputs) + + def _build_filter_dict_graph(self): + return dataset_ops.Dataset.range(10).map( + lambda x: {"foo": x * 2, "bar": x ** 2}).filter( + lambda d: math_ops.equal(d["bar"] % 2, 0)).map( + lambda d: d["foo"] + d["bar"]) + + def testFilterDictCore(self): + num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)]) + self.run_core_tests(self._build_filter_dict_graph, None, num_outputs) + + def _build_sparse_filter(self): + + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + def _filter_fn(_, i): + return math_ops.equal(i % 2, 0) + + return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( + lambda x, i: x) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index c950e4857ef0d4d1340fdded1010800e6771939e..d4fbaa5cdcdd315aa0524134b48eb0515169722c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -21,11 +21,18 @@ import random import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -124,7 +131,7 @@ class FlatMapDatasetTest(test.TestCase): def testSparse(self): def _map_fn(i): - return sparse_tensor.SparseTensor( + return sparse_tensor.SparseTensorValue( indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) def _flat_map_fn(x): @@ -147,5 +154,77 @@ class FlatMapDatasetTest(test.TestCase): sess.run(get_next) +class FlatMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + # Complicated way of saying range(start, start+25). + def build_ds(start): + + def map_fn(x): + return dataset_ops.Dataset.range(x, x + 5) + + return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn) + + self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25) + + def testMapThenFlatMap(self): + + def build_ds(): + + def flat_map_fn(_): + + def map_fn(y): + return 10 * math_ops.to_int32(y) + + return dataset_ops.Dataset.range(100).map(map_fn) + + return dataset_ops.Dataset.range(5).flat_map(flat_map_fn) + + self.run_core_tests(build_ds, None, 500) + + def testCaptureDefunInMapFn(self): + + def build_ds(): + + def map_fn(x): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)]) + + return dataset_ops.Dataset.range(100).flat_map(map_fn) + + self.run_core_tests(build_ds, None, 100) + + def testDisallowVariableCapture(self): + + def build_ds(): + test_var = variable_scope.get_variable( + name="test_var", shape=(), use_resource=True) + return dataset_ops.Dataset.range(5).flat_map( + lambda _: dataset_ops.Dataset.from_tensor_slices([test_var])) + + self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError) + + def testDisallowCapturingStatefulOps(self): + + def build_ds(): + + def flat_map_fn(_): + + def map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(map_fn) + + return dataset_ops.Dataset.range(5).flat_map(flat_map_fn) + + self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 0299e3a1b7d240e75b869ef4595293f691958623..e66ed3f7aa2a512813ef353d2d0744ae67005884 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -22,8 +22,10 @@ import math import threading import time +import numpy as np from six.moves import zip_longest +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.framework import dtypes @@ -185,8 +187,9 @@ class InterleaveDatasetTest(test.TestCase): sess.run(next_element) def testSparse(self): + def _map_fn(i): - return sparse_tensor.SparseTensor( + return sparse_tensor.SparseTensorValue( indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) def _interleave_fn(x): @@ -209,6 +212,46 @@ class InterleaveDatasetTest(test.TestCase): sess.run(get_next) +class InterleaveDatasetSeriazationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, input_values, cycle_length, block_length): + repeat_count = 2 + return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + repeat_count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length) + + def testSerializationCore(self): + input_values = np.array([4, 5, 6], dtype=np.int64) + num_outputs = np.sum(input_values) * 2 + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + lambda: self._build_iterator_graph( + input_values, cycle_length * 2, block_length * 1), + num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # pylint: enable=g-long-lambda + + class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index d8e7f9d5933b4291b2d905aeb3c54439e0958a4c..e9a07da84a8c80c09ebd4dab0b1d69febe1c9790 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -23,10 +23,9 @@ import threading import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import error_ops -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -44,10 +43,7 @@ from tensorflow.python.ops import script_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat @@ -630,9 +626,13 @@ class MapDatasetTest(test.TestCase): self.assertAllEqual(a.dense_shape, b.dense_shape) def testSparse(self): + def _sparse(i): - return sparse_tensor.SparseTensor( - indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]) + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + iterator = (dataset_ops.Dataset.range(10) .map(_sparse) .make_initializable_iterator()) @@ -643,24 +643,26 @@ class MapDatasetTest(test.TestCase): sess.run(init_op) for i in range(10): actual = sess.run(get_next) - expected = sparse_tensor.SparseTensor( - indices=[[0, 0]], values=[i], dense_shape=[1, 1]) self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, expected.eval()) + self.assertSparseValuesEqual(actual, _sparse(i)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testSparseChain(self): + def _sparse(i): - return sparse_tensor.SparseTensor( - indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]) + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + def _check(i): - self.assertTrue(isinstance(i, sparse_tensor.SparseTensor)) + self.assertTrue(sparse_tensor.is_sparse(i)) return sparse_ops.sparse_concat(0, [i, i]) - iterator = (dataset_ops.Dataset.range(10) - .map(_sparse).map(_check) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.range(10).map(_sparse).map(_check) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -668,10 +670,8 @@ class MapDatasetTest(test.TestCase): sess.run(init_op) for i in range(10): actual = sess.run(get_next) - expected = sparse_tensor.SparseTensor( - indices=[[0, 0], [1, 0]], values=[i, i], dense_shape=[2, 1]) self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, expected.eval()) + self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval()) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -702,20 +702,14 @@ class MapDatasetTest(test.TestCase): sess.run(init_op) -class MapDatasetSerializationTest(test.TestCase): +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): def setUp(self): self._tensor_slice_len = 7 self._num_epochs = 14 self._num_outputs = self._tensor_slice_len * self._num_epochs - def tearDown(self): - # Remove all checkpoint files. - prefix = self._ckpt_path() - pattern = prefix + "*" - files = gfile.Glob(pattern) - map(gfile.Remove, files) - def _build_ds(self, multiplier=37.0): components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * np.arange(self._tensor_slice_len)[:, np.newaxis], @@ -727,292 +721,11 @@ class MapDatasetSerializationTest(test.TestCase): return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(self._num_epochs)) - def _build_graph(self, multiplier=37.0, build_saveable=True): - ds = self._build_ds(multiplier) - iterator = ds.make_initializable_iterator() - - if build_saveable: - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - self._add_iterator_ops_to_collection(init_op, get_next) - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - def _build_empty_graph(self, output_types, output_shapes): - iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - saver = saver_lib.Saver() - get_next = iterator.get_next() - return get_next, saver - - def _add_iterator_ops_to_collection(self, init_op, get_next): - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next[0]) - ops.add_to_collection("iterator_ops", get_next[1]) - ops.add_to_collection("iterator_ops", get_next[2]) - - def _get_iterator_ops_from_collection(self): - init_op, get_next_1, get_next_2, get_next_3 = ops.get_collection( - "iterator_ops") - return init_op, (get_next_1, get_next_2, get_next_3) - - def _ckpt_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) - - def _save(self, sess, saver): - saver.save(sess, self._ckpt_path()) - - def _restore(self, saver, sess): - saver.restore(sess, self._latest_ckpt()) - - def _import_meta_graph(self): - meta_file_path = self._ckpt_path() + ".meta" - return saver_lib.import_meta_graph(meta_file_path) - - def _testReadWithBreaks(self, break_points, init_before_restore=False): - expected = [] - actual = [] - # Generate the ground truth. - with ops.Graph().as_default() as g: - init_op, get_next_op, _ = self._build_graph() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(self._num_outputs): - expected.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Run and checkpoint after first break_point. - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_points[0]): - actual.append(sess.run(get_next_op)) - self._save(sess, saver) - - # Load from checkpoint and continue running while stopping at each - # subsequent checkpoint. - for i in range(len(break_points)): - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = self._get_iterator_ops_from_collection() - with self.test_session(graph=g) as sess: - if init_before_restore: - sess.run(init_op) - self._restore(saver, sess) - start = break_points[i] - end = break_points[ - i + 1] if i < len(break_points) - 1 else self._num_outputs - for _ in range(end - start): - actual.append(sess.run(get_next_op)) - self._save(sess, saver) - if end == self._num_outputs: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - self._match(expected, actual) - - def _match(self, expected, actual): - self.assertEqual(len(expected), len(actual)) - for expected_tuple, actual_tuple in zip(expected, actual): - self.assertEqual(expected_tuple[0], actual_tuple[0]) - self.assertSequenceEqual(expected_tuple[1].tolist(), - actual_tuple[1].tolist()) - self.assertEqual(expected_tuple[2], actual_tuple[2]) - - def _does_not_match(self, expected, actual): - with self.assertRaises(AssertionError): - self._match(expected, actual) - - def testSaveRestore(self): - self._testReadWithBreaks([4]) - self._testReadWithBreaks([13]) - self._testReadWithBreaks([18]) - self._testReadWithBreaks([23]) - - def testSaveUnusedIterator(self): - self._testReadWithBreaks([0]) - - def testSaveFullyUsedIterator(self): - self._testReadWithBreaks([self._num_outputs]) - - def testMultipleBreaks(self): - self._testReadWithBreaks([0, 5, 9, 15, 25, 32]) - - def testIdempotence(self): - # Attempt to save iterator immediately after restoring. - self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32]) - - def testInitThenRestore(self): - self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True) - - def testRestoreExhaustedIterator(self): - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(self._num_outputs): - sess.run(get_next_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - self._save(sess, saver) - - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = self._get_iterator_ops_from_collection() - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testResetRestoredIterator(self): - expected = [] - # Collect ground truth containing all outputs. - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph() - break_point = self._num_outputs // 2 - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - expected.append(sess.run(get_next_op)) - self._save(sess, saver) - for _ in range(self._num_outputs - break_point): - expected.append(sess.run(get_next_op)) - - actual = [] - # Restore from checkpoint and then run init_op. - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = self._get_iterator_ops_from_collection() - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - sess.run(init_op) - for _ in range(self._num_outputs): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - self._match(expected, actual) - - def testRestoreInModifiedGraph(self): - expected = [] - actual_without_restore = [] - actual = [] - break_point = 10 - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph(multiplier=15.0) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - expected.append(sess.run(get_next_op)) - actual.extend(expected) - self._save(sess, saver) - for _ in range(self._num_outputs - break_point): - expected.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Collect outputs by running modified graph. - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph(multiplier=30.0) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(self._num_outputs): - actual_without_restore.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Restore the checkpoint in the modified graph. - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph(multiplier=30.0) - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - for _ in range(self._num_outputs - break_point): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Ensure the modified graph gets overridden when restoring checkpoint. - self._does_not_match(expected, actual_without_restore) - # Expect that the outputs are what we would expect if we ran the old - # graph. - self._match(expected, actual) - - # TODO(srbs): Add this test to dataset_serialization_test_base.py. - def testRestoreInEmptyGraph(self): - expected = [] - actual = [] - break_point = 10 - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph(multiplier=15.0) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - sess.run(get_next_op) - self._save(sess, saver) - for _ in range(self._num_outputs - break_point): - expected.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - ds = self._build_ds() - output_types = ds.output_types - output_shapes = ds.output_shapes - - with ops.Graph().as_default() as g: - get_next_op, saver = self._build_empty_graph(output_types, output_shapes) - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - for _ in range(self._num_outputs - break_point): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Expect that the outputs are what we would expect if we ran the old - # graph. - self._match(expected, actual) - - def testDoNotBuildSaveable(self): - break_point = 10 - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph(multiplier=15.0) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - sess.run(get_next_op) - self._save(sess, saver) - - expected = [] - # Collect ground truth by running modified graph. - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph(multiplier=30.0) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(self._num_outputs): - expected.append(sess.run(get_next_op)) - - actual = [] - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = self._build_graph( - multiplier=30.0, build_saveable=False) - with self.test_session(graph=g) as sess: - # Since the SaveableObject was not added to Saver's list - # of saveables, iterator state is not restored by saver.restore(). - self._restore(saver, sess) - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next_op) - sess.run(init_op) - for _ in range(self._num_outputs): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - self._match(expected, actual) + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) def testSaveStatefulFunction(self): @@ -1024,26 +737,7 @@ class MapDatasetSerializationTest(test.TestCase): return dataset_ops.Dataset.range(100).map(_map_fn) - def _build_graph(): - ds = _build_ds() - iterator = ds.make_initializable_iterator() - - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - break_point = 10 - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = _build_graph() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - sess.run(get_next_op) - with self.assertRaises(errors.InvalidArgumentError): - self._save(sess, saver) + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) def testCaptureVariableInMapFn(self): @@ -1053,27 +747,7 @@ class MapDatasetSerializationTest(test.TestCase): return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( lambda _: counter_var.assign_add(1))) - def _build_graph(): - ds = _build_ds() - iterator = ds.make_initializable_iterator() - - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - break_point = 10 - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = _build_graph() - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for _ in range(break_point): - sess.run(get_next_op) - with self.assertRaises(errors.InvalidArgumentError): - self._save(sess, saver) + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) def testCaptureDefunInMapFn(self): num_outputs = 100 @@ -1086,46 +760,7 @@ class MapDatasetSerializationTest(test.TestCase): return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - def _build_graph(): - ds = _build_ds() - iterator = ds.make_initializable_iterator() - - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - break_point = 10 - expected = [] - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = _build_graph() - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for _ in range(break_point): - sess.run(get_next_op) - self._save(sess, saver) - for _ in range(num_outputs - break_point): - expected.append(sess.run(get_next_op)) - - with ops.Graph().as_default() as g: - ds = _build_ds() - output_types = ds.output_types - output_shapes = ds.output_shapes - - actual = [] - with ops.Graph().as_default() as g: - get_next_op, saver = self._build_empty_graph(output_types, output_shapes) - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - for _ in range(num_outputs - break_point): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - self.assertSequenceEqual(expected, actual) + self.run_core_tests(_build_ds, None, num_outputs) def testBuildDefunInMapFn(self): num_outputs = 100 @@ -1143,46 +778,23 @@ class MapDatasetSerializationTest(test.TestCase): return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - def _build_graph(): - ds = _build_ds() - iterator = ds.make_initializable_iterator() + self.run_core_tests(_build_ds, None, num_outputs) - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - break_point = 10 - expected = [] - with ops.Graph().as_default() as g: - init_op, get_next_op, saver = _build_graph() - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for _ in range(break_point): - sess.run(get_next_op) - self._save(sess, saver) - for _ in range(num_outputs - break_point): - expected.append(sess.run(get_next_op)) +class IgnoreErrorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - with ops.Graph().as_default() as g: - ds = _build_ds() - output_types = ds.output_types - output_shapes = ds.output_shapes + def _build_ds(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors()) - actual = [] - with ops.Graph().as_default() as g: - get_next_op, saver = self._build_empty_graph(output_types, output_shapes) - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - for _ in range(num_outputs - break_point): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - self.assertSequenceEqual(expected, actual) + def testIgnoreErrorsCore(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) + num_outputs = 4 + self.run_core_tests(lambda: self._build_ds(components), + lambda: self._build_ds(diff_components), num_outputs) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_impl.py b/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py similarity index 51% rename from tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_impl.py rename to tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py index a640dfe7dfbcce96261589c7fc49107deaefdd54..3d120a3071ef730f21221e3291d8c84385b51aa3 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_impl.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py @@ -12,37 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sigmoid bijector.""" - +"""Tests for the experimental input pipeline ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Sigmoid", -] - +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test -class Sigmoid(bijector.Bijector): - """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" - def __init__(self, validate_args=False, name="sigmoid"): - super(Sigmoid, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) +class PrefetchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - def _forward(self, x): - return math_ops.sigmoid(x) + def build_dataset(self, seed): + return dataset_ops.Dataset.range(100).prefetch(10).shuffle( + buffer_size=10, seed=seed, reshuffle_each_iteration=False) - def _inverse(self, y): - return math_ops.log(y) - math_ops.log1p(-y) + def testCore(self): + num_outputs = 100 + self.run_core_tests(lambda: self.build_dataset(10), + lambda: self.build_dataset(20), num_outputs) - def _inverse_log_det_jacobian(self, y): - return -math_ops.log(y) - math_ops.log1p(-y) - def _forward_log_det_jacobian(self, x): - return -nn_ops.softplus(-x) - nn_ops.softplus(x) +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index f59ac760dc83a504e563f055b91f1002cb0c80fc..8e6ad061a11752ab7b1ffc13c90b4fa52f67d6aa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os +from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops @@ -194,6 +195,27 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testCounter(self): + """Test dataset construction using `count`.""" + iterator = (counter.Counter(start=3, step=4) + .make_one_shot_iterator()) + get_next = iterator.get_next() + self.assertEqual([], get_next.shape.as_list()) + self.assertEqual(dtypes.int64, get_next.dtype) + + negative_iterator = (counter.Counter(start=0, step=-1) + .make_one_shot_iterator()) + negative_get_next = negative_iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(3, sess.run(get_next)) + self.assertEqual(3 + 4, sess.run(get_next)) + self.assertEqual(3 + 2 * 4, sess.run(get_next)) + + self.assertEqual(0, sess.run(negative_get_next)) + self.assertEqual(-1, sess.run(negative_get_next)) + self.assertEqual(-2, sess.run(negative_get_next)) + def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 6b5b53cc0f8f2d1df5622a5bc5e2f8ef04c6342a..72745ec7525ad0578934fb2051018f6531938088 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -22,8 +22,10 @@ import os import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -474,5 +476,83 @@ class ShuffleDatasetSerializationTest(test.TestCase): self.assertEqual(expected_outputs_sorted, sorted(actual)) +class ShuffleAndRepeatTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed, count=5, num_elements=20): + return dataset_ops.Dataset.range(num_elements).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + + def testCorrectOutput(self): + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertSequenceEqual( + sorted(output), sorted( + np.array([range(20) for _ in range(5)]).flatten())) + for i in range(5): + self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) + + def testReshuffling(self): + # Check that the output orders of different epochs are indeed different. + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + for i in range(4): + epoch1 = output[i * 20:(i + 1) * 20] + epoch2 = output[(i + 1) * 20:(i + 2) * 20] + self.assertNotEqual(epoch1, epoch2) + + def testSameOrderForSameSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertEqual(output1, output2) + + def testDifferentOrderForDifferentSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountNone(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountMinusOne(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testInfiniteOutputs(self): + # Asserting the iterator is exhausted after producing 100 items should fail. + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + + def testInfiniteEmpty(self): + with self.assertRaises(errors.OutOfRangeError): + self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), + [], 100) + with self.assertRaises(errors.OutOfRangeError): + self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [], + 100) + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..07bdf920446e953c2a1abaf495d2e9e1256106fd --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -0,0 +1,257 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import stats_ops +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class StatsDatasetTest(test.TestCase): + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def testBytesProduced(self): + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + expected_sum = 0.0 + for i in range(100): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1)) + expected_sum += i * 8.0 + self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0) + self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) + + def testLatencyStats(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + + def testReinitialize(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(stats_aggregator_subscriber) + for j in range(5): + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float((j * 100) + i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", (j + 1) * 100.0) + + def testNoAggregatorRegistered(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMultipleTags(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.latency_stats("record_latency_2")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(i + 1)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency_2", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency_2", 100.0) + + def testRepeatedTags(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(2 * (i + 1))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + + def testMultipleIteratorsSameAggregator(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator_0 = dataset.make_initializable_iterator() + iterator_1 = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0), + stats_aggregator.subscribe(iterator_1)] + next_element = iterator_0.get_next() + iterator_1.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator_0.initializer, iterator_1.initializer, + stats_aggregator_subscribers]) + for i in range(100): + self.assertEqual(i * 2, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(2 * (i + 1))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + + def testMultipleStatsAggregatorsSameIteratorFail(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator_0 = stats_ops.StatsAggregator() + stats_aggregator_1 = stats_ops.StatsAggregator() + + with self.test_session() as sess: + sess.run(stats_aggregator_0.subscribe(iterator)) + # TODO(mrry): Consider making this allowable (and also allowing + # aggregators to unsubscribe). + with self.assertRaises(errors.FailedPreconditionError): + sess.run(stats_aggregator_1.subscribe(iterator)) + + +class StatsDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset_bytes_stats(self, num_elements): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + + def testBytesStatsDatasetSaveableCore(self): + num_outputs = 100 + self.run_core_tests( + lambda: self._build_dataset_bytes_stats(num_outputs), + lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index d6aaa12f5b87ea1781346aea0010f23656ffc7d0..1f35ee056b7f897ce5e7488b205ecf5a05ef0268 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -14,6 +14,7 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") py_library( name = "dataset_ops", srcs = [ + "counter.py", "dataset_ops.py", ], srcs_version = "PY2AND3", @@ -39,6 +40,25 @@ py_library( ], ) +py_library( + name = "random_ops", + srcs = [ + "random_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "readers", srcs = [ @@ -61,6 +81,19 @@ py_library( ], ) +py_library( + name = "shuffle_ops", + srcs = [ + "shuffle_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":random_ops", + ":transformation_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_library( name = "transformation_ops", srcs = [ @@ -71,6 +104,7 @@ py_library( "interleave_ops.py", "resampling.py", "scan_ops.py", + "stats_ops.py", ], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index cc63baed81334521746fea1161003615535c371f..e8b2d44a8b57d471f11b128622b6121f699fbf85 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -112,8 +112,10 @@ def filter_irregular_batches(batch_size): tensor_batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - flattened = _RestructuredDataset(dataset, - tuple(nest.flatten(dataset.output_types))) + flattened = _RestructuredDataset( + dataset, + tuple(nest.flatten(dataset.output_types)), + output_classes=tuple(nest.flatten(dataset.output_classes))) def _predicate(*xs): """Return `True` if this element is a full batch.""" @@ -135,7 +137,11 @@ def filter_irregular_batches(batch_size): known_shapes = nest.map_structure(_set_first_dimension, dataset.output_shapes) - return _RestructuredDataset(filtered, dataset.output_types, known_shapes) + return _RestructuredDataset( + filtered, + dataset.output_types, + known_shapes, + output_classes=dataset.output_classes) return _apply_fn @@ -237,6 +243,10 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): output_shapes=self.output_shapes, output_types=self.output_types) + @property + def output_classes(self): + return (ops.Tensor, ops.Tensor, ops.Tensor) + @property def output_shapes(self): num_elements = tensor_shape.Dimension(None) @@ -252,7 +262,11 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): class _RestructuredDataset(dataset_ops.Dataset): """An internal helper for changing the structure and shape of a dataset.""" - def __init__(self, dataset, output_types, output_shapes=None): + def __init__(self, + dataset, + output_types, + output_shapes=None, + output_classes=None): """Creates a new dataset with the given output types and shapes. The given `dataset` must have a structure that is convertible: @@ -268,6 +282,8 @@ class _RestructuredDataset(dataset_ops.Dataset): output_types: A nested structure of `tf.DType` objects. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. If omitted, the shapes will be inherited from `dataset`. + output_classes: (Optional.) A nested structure of class types. + If omitted, the class types will be inherited from `dataset`. Raises: ValueError: If either `output_types` or `output_shapes` is not compatible @@ -307,10 +323,21 @@ class _RestructuredDataset(dataset_ops.Dataset): output_shapes)) self._output_shapes = nest.map_structure_up_to( output_types, tensor_shape.as_shape, output_shapes) + if output_classes is None: + # Inherit class types from the original `dataset`. + self._output_classes = nest.pack_sequence_as(output_types, + nest.flatten( + dataset.output_classes)) + else: + self._output_classes = output_classes def _as_variant_tensor(self): return self._dataset._as_variant_tensor() # pylint: disable=protected-access + @property + def output_classes(self): + return self._output_classes + @property def output_types(self): return self._output_types @@ -326,10 +353,6 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - if sparse.any_sparse(self._output_types): - # TODO(b/63669786): support batching of sparse tensors - raise TypeError("Batching of sparse tensors is not currently supported") - self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") self._num_parallel_batches = ops.convert_to_tensor( @@ -345,8 +368,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): batch_size=self._batch_size, num_parallel_batches=self._num_parallel_batches, output_types=nest.flatten( - sparse.unwrap_sparse_types(self.output_types)), - output_shapes=nest.flatten(self.output_shapes)) + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) # pylint: enable=protected-access @property @@ -366,17 +390,12 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset - and then combines them into a batch. Similarly to `batch_and_drop_remainder`, - if the batch size does not evenly divide the input dataset size, this - transformation will drop the final smaller element. - - - Functionally, it is equivalent to `map` followed by - `batch_and_drop_remainder`. However, by fusing the two transformations - together, the implementation can be more efficient. This transformation is a - stop gap solution for performance critical workloads. Once automatic input - pipeline optimization are implemented, the fusing of map and batch will not - need to be exposed at the API level and this method will be removed. + and then combines them into a batch. Functionally, it is equivalent to `map` + followed by `batch`. However, by fusing the two transformations together, the + implementation can be more efficient. Surfacing this transformation in the API + is temporary. Once automatic input pipeline optimization is implemented, + the fusing of `map` and `batch` will happen automatically and this API will be + deprecated. Args: map_func: A function mapping a nested structure of tensors to another @@ -394,9 +413,6 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): """ def _apply_fn(dataset): - if sparse.any_sparse(dataset.output_types): - # TODO(b/63669786): support batching of sparse tensors - raise TypeError("Batching of sparse tensors is not currently supported") return _MapAndBatchDataset(dataset, map_func, batch_size, num_parallel_batches) diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py new file mode 100644 index 0000000000000000000000000000000000000000..63226fe78163c59025623a362d17c400fbe57c67 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Counter Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import scan_ops + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + + +def Counter(start=0, step=1, dtype=dtypes.int64): + """Creates a `Dataset` of a `step`-separated count startin from `start`. + + For example: + + ```python + Dataset.count() == [0, 1, 2, ...) + Dataset.count(2) == [2, 3, ...) + Dataset.count(2, 5) == [2, 7, 12, ...) + Dataset.count(0, -1) == [0, -1, -2, ...) + Dataset.count(10, -1) == [10, 9, ...) + ``` + + Args: + start: starting value for count. + step: step size. + dtype: counter data type. + + Returns: + A `Dataset` of scalar elements. + """ + with ops.name_scope("counter"): + start = ops.convert_to_tensor(start, dtype=dtype, name="start") + step = ops.convert_to_tensor(step, dtype=dtype, name="step") + return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( + scan_ops.scan(start, lambda state, _: (state + step, state))) diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 45d6dbe7438957029b4d6b71e181cb1fc3596ecb..626a9e0edcea5928b1636c1a2a86e83657c966a5 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -21,7 +21,6 @@ from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import grouping - from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.ops import gen_dataset_ops @@ -48,6 +47,10 @@ class Dataset(dataset_ops.Dataset): def _as_variant_tensor(self): return self._dataset._as_variant_tensor() # pylint: disable=protected-access + @property + def output_classes(self): + return self._dataset.output_classes + @property def output_shapes(self): return self._dataset.output_shapes diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 194b61151390e2dcc3fa13b618003cbe5697806f..aa629cba479102ee4244884e7c546615b28cf4e5 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -63,9 +63,14 @@ class IgnoreErrorsDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten(self.output_shapes), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), output_types=nest.flatten( - sparse.unwrap_sparse_types(self.output_types))) + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 86337271bca79ea8bffda28fac79e41dc39f3fd3..ef91c56726e969053fdad667dda3e89430045652 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -88,15 +88,21 @@ def group_by_window(key_func, class _VariantDataset(dataset_ops.Dataset): """A Dataset wrapper for a tf.variant-typed function argument.""" - def __init__(self, dataset_variant, output_types, output_shapes): + def __init__(self, dataset_variant, output_types, output_shapes, + output_classes): super(_VariantDataset, self).__init__() self._dataset_variant = dataset_variant self._output_types = output_types self._output_shapes = output_shapes + self._output_classes = output_classes def _as_variant_tensor(self): return self._dataset_variant + @property + def output_classes(self): + return self._output_classes + @property def output_shapes(self): return self._output_shapes @@ -138,17 +144,21 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - @function.Defun( - *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types))) + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) def tf_key_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. - for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types) + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) # pylint: disable=protected-access if dataset_ops._should_unpack_args(nested_args): ret = key_func(*nested_args) @@ -170,14 +180,15 @@ class GroupByWindowDataset(dataset_ops.Dataset): def tf_reduce_func(key, window_dataset_variant): """A wrapper for Defun that facilitates shape inference.""" key.set_shape([]) - window_dataset = _VariantDataset(window_dataset_variant, - input_dataset.output_types, - input_dataset.output_shapes) + window_dataset = _VariantDataset( + window_dataset_variant, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) if not isinstance(window_dataset, dataset_ops.Dataset): raise TypeError("`window_dataset` must return a `Dataset` object.") output_dataset = reduce_func(key, window_dataset) if not isinstance(output_dataset, dataset_ops.Dataset): raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_classes = output_dataset.output_classes self._output_types = output_dataset.output_types self._output_shapes = output_dataset.output_shapes return output_dataset._as_variant_tensor() # pylint: disable=protected-access @@ -185,6 +196,10 @@ class GroupByWindowDataset(dataset_ops.Dataset): self._reduce_func = tf_reduce_func self._reduce_func.add_to_graph(ops.get_default_graph()) + @property + def output_classes(self): + return self._output_classes + @property def output_shapes(self): return self._output_shapes @@ -203,5 +218,6 @@ class GroupByWindowDataset(dataset_ops.Dataset): reduce_func=self._reduce_func, window_size_func=self._window_size_func, output_types=nest.flatten( - sparse.unwrap_sparse_types(self.output_types)), - output_shapes=nest.flatten(self.output_shapes)) + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 830642c0401b281e14e4dc7f7265ab6c77bbe513..53324e06e7f1dc249388410f0e14e42336630cd1 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -36,17 +36,21 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun( - *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types))) + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) def tf_map_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. - for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types) + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) if dataset_ops._should_unpack_args(nested_args): # pylint: disable=protected-access dataset = map_func(*nested_args) else: @@ -55,6 +59,7 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`map_func` must return a `Dataset` object.") + self._output_classes = dataset.output_classes self._output_types = dataset.output_types self._output_shapes = dataset.output_shapes @@ -79,8 +84,13 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): self._sloppy, f=self._map_func, output_types=nest.flatten( - sparse.unwrap_sparse_types(self.output_types)), - output_shapes=nest.flatten(self.output_shapes)) + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_classes(self): + return self._output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7d727165feabb101549567f28a2dfa07083de244 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Datasets for random number generators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class RandomDataset(dataset_ops.Dataset): + """A `Dataset` of pseudorandom values.""" + + def __init__(self, seed=None): + """A `Dataset` of pseudorandom values.""" + super(RandomDataset, self).__init__() + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + return gen_dataset_ops.random_dataset( + seed=self._seed, + seed2=self._seed2, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.int64 diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 632082b5f1edb6c3aa25cacb0d4831f9e9e7488c..347e5edc7b0d479dfa260e8cec500ffaaba375be 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -164,7 +164,7 @@ def read_batch_features(file_pattern, shuffling but would increase memory usage and startup time. Returns: - A dict from keys in features to Tensor or SparseTensor objects. + A dict from keys in features to `Tensor` or `SparseTensor` objects. """ filenames = _get_file_names(file_pattern, randomize_input) if reader_args: @@ -179,6 +179,7 @@ def read_batch_features(file_pattern, dataset = dataset.shuffle(capacity) dataset = dataset.batch(batch_size) dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) + dataset = dataset.prefetch(1) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() return outputs @@ -269,6 +270,10 @@ class _SqlDataset(dataset_ops.Dataset): nest.flatten(self.output_types), nest.flatten(self.output_shapes)) + @property + def output_classes(self): + return nest.map_structure(lambda _: ops.Tensor, self._output_types) + @property def output_shapes(self): return nest.map_structure(lambda _: tensor_shape.TensorShape([]), diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 2cfc0709cda37491f8cfa61c4f05b380931ab603..2744786e9eec4c9268ba854df6ea761339bb0b4e 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -53,6 +53,7 @@ class _ScanDataset(dataset_ops.Dataset): [t.dtype for t in nest.flatten(self._initial_state)]) # Will be populated by calling `tf_scan_func`. + self._output_classes = None self._output_shapes = None self._output_types = None @@ -68,13 +69,16 @@ class _ScanDataset(dataset_ops.Dataset): flat_new_state_shapes = [] @function.Defun(*(flat_state_types + nest.flatten( - sparse.unwrap_sparse_types(input_dataset.output_types)))) + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) def tf_scan_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the state and input_dataset. - for arg, shape in zip( - args, - flat_state_shapes + nest.flatten(input_dataset.output_shapes)): + # TODO(b/69424092): Check that neither inputs nor outputs are sparse. + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, + flat_state_shapes + nest.flatten(dense_shapes)): arg.set_shape(shape) pivot = len(flat_state_shapes) @@ -108,6 +112,8 @@ class _ScanDataset(dataset_ops.Dataset): "state. Expected %s; got %s." % (self._state_types, nest.pack_sequence_as( self._state_types, [t.dtype for t in flat_new_state]))) + self._output_classes = nest.pack_sequence_as( + output_value, [ops.Tensor for _ in flat_output_value]) self._output_types = nest.pack_sequence_as( output_value, [t.dtype for t in flat_output_value]) @@ -147,8 +153,13 @@ class _ScanDataset(dataset_ops.Dataset): self._scan_func.captured_inputs, f=self._scan_func, output_types=nest.flatten( - sparse.unwrap_sparse_types(self.output_types)), - output_shapes=nest.flatten(self.output_shapes)) + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_classes(self): + return self._output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..410989fad4f2a3bb8c9051c094ce8ab7b2eee96c --- /dev/null +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -0,0 +1,120 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental shuffle ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import gen_dataset_ops + + +class _ShuffleAndRepeatDataset(dataset_ops.Dataset): + """A `Dataset` that fuses `shuffle` and `repeat`.""" + + def __init__(self, + input_dataset, + buffer_size, + count=None, + seed=None): + """See `Dataset.map()` for details.""" + super(_ShuffleAndRepeatDataset, self).__init__() + self._input_dataset = input_dataset + self._buffer_size = ops.convert_to_tensor( + buffer_size, dtype=dtypes.int64, name="buffer_size") + if count is None: + self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") + else: + self._count = ops.convert_to_tensor( + count, dtype=dtypes.int64, name="count") + + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + input_resource = self._input_dataset._as_variant_tensor() + return gen_dataset_ops.shuffle_and_repeat_dataset( + input_resource, + buffer_size=self._buffer_size, + count=self._count, + seed=self._seed, + seed2=self._seed2, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + # pylint: enable=protected-access + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +def shuffle_and_repeat(buffer_size, count=None, seed=None): + """Shuffles and repeats a Dataset returning a new permutation for each epoch. + + `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))` + + is equivalent to + + `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)` + + The difference is that the latter dataset is not serializable. So, + if you need to checkpoint an input pipeline with reshuffling you must use + this implementation. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + maximum number elements that will be buffered when prefetching. + count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + number of times the dataset should be repeated. The default behavior + (if `count` is `None` or `-1`) is for the dataset be repeated + indefinitely. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): # pylint: disable=missing-docstring + return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b8875bd533ddc9e2c195646619dccf3aab5225e4 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -0,0 +1,177 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for gathering statistics from `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +class StatsAggregator(object): + """A stateful resource that aggregates statistics from one or more iterators. + + To record statistics, use one of the custom transformation functions defined + in this module when defining your @{tf.data.Dataset}. All statistics will be + aggregated by the `StatsAggregator` that is associated with a particular + iterator (see below). For example, to record the total number of bytes + produced by iterating over a dataset: + + ```python + dataset = ... + dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes")) + ``` + + To associate a `StatsAggregator` with a @{tf.data.Iterator} object, use + the following pattern: + + ```python + dataset = ... + iterator = dataset.make_one_shot_iterator() + stats_aggregator = stats_ops.StatsAggregator() + set_op = stats_op.set_stats_aggregator_op(iterator, stats_aggregator) + + with tf.Session() as sess: + # Running `set_op` will associate `iterator` with `stats_aggregator`. + sess.run(set_op) + ``` + + To get a protocol buffer summary of the currently aggregated statistics, + use the `StatsAggregator.get_summary()` tensor. The easiest way to do this + is to add the returned tensor to the @{tf.GraphKeys.SUMMARIES} collection, + so that the summaries will be included with any existing summaries. + + ```python + stats_aggregator = stats_ops.StatsAggregator() + stats_summary = stats_aggregator.get_summary() + tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary) + ``` + + Note: This interface is experimental and expected to change. In particular, + we expect to add other implementations of `StatsAggregator` that provide + different ways of exporting statistics, and add more types of statistics. + """ + + def __init__(self): + """Creates a `StatsAggregator`.""" + self._resource = gen_dataset_ops.stats_aggregator_handle() + + def get_summary(self): + """Returns a string @{tf.Tensor} that summarizes the aggregated statistics. + + The returned tensor will contain a serialized @{tf.summary.Summary} protocol + buffer, which can be used with the standard TensorBoard logging facilities. + + Returns: + A scalar string @{tf.Tensor} that summarizes the aggregated statistics. + """ + return gen_dataset_ops.stats_aggregator_summary(self._resource) + + def subscribe(self, iterator): + """Returns a @{tf.Operation} to associate this aggregator with `iterator`. + + Note: Each @{tf.data.Iterator} can be associated with at most one + `StatsAggregator`. After running the operation that this function + returns, all statistics recorded in the iteration of `iterator` + will be stored in `stats_aggregator`. + + Args: + iterator: A @{tf.data.Iterator} object. + + Returns: + A @{tf.Operation} that, when run, associates this aggregator with + `iterator`. + """ + if not isinstance(iterator, iterator_ops.Iterator): + raise TypeError("`iterator` must be a `tf.data.Iterator` object.") + return gen_dataset_ops.iterator_set_stats_aggregator( + iterator._iterator_resource, self._resource) # pylint: disable=protected-access + + +def bytes_produced_stats(tag): + """Records the number of bytes produced by each element of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with an iterator + over the output dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will + be associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset, + tag) + + return _apply_fn + + +def latency_stats(tag): + """Records the latency of producing each element of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with an iterator + over the output dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will + be associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag) + + return _apply_fn + + +class _StatsDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and also records statistics.""" + + def __init__(self, input_dataset, op_function, tag): + super(_StatsDataset, self).__init__() + self._input_dataset = input_dataset + self._op_function = op_function + self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string) + + def _as_variant_tensor(self): + return self._op_function( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._tag, + output_shapes=nest.flatten(self.output_shapes), + output_types=nest.flatten(self.output_types)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig index d3d201afd5761e7c5c136301c779222bedc68492..cafb9314caee1c4907786b8101e7c71bd7095306 100644 --- a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig @@ -2,7 +2,7 @@ %include "net/proto/swig/protofunc.swig" -#ifndef MUST_USE_RESULT +#ifndef ABSL_MUST_USE_RESULT #error Use this file only as a %include or %import after google.swig. #endif diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 145b9495ff40f8095b50d00e576333fdf5d7acdf..95848af69950bdaa680c41daecd8cbd8f3174f8e 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -60,6 +60,7 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:spectral_ops", "//tensorflow/python:state_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -204,6 +205,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "half_normal_test", + size = "medium", + srcs = ["python/kernel_tests/half_normal_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "inverse_gamma_test", srcs = ["python/kernel_tests/inverse_gamma_test.py"], @@ -419,6 +438,7 @@ cuda_py_test( "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:spectral_ops_test_util", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 0d12d838932e3a46e07f4a4242b889296c6e13c4..7b401e178f35fe56e4eb461936565f5c630ec4cf 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -36,6 +36,7 @@ from tensorflow.contrib.distributions.python.ops.distribution_util import softpl from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag from tensorflow.contrib.distributions.python.ops.estimator import * from tensorflow.contrib.distributions.python.ops.geometric import * +from tensorflow.contrib.distributions.python.ops.half_normal import * from tensorflow.contrib.distributions.python.ops.independent import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * from tensorflow.contrib.distributions.python.ops.logistic import * @@ -107,6 +108,7 @@ _allowed_symbols = [ 'Gamma', 'GammaWithSoftplusConcentrationRate', 'Geometric', + 'HalfNormal', 'Independent', 'InverseGamma', 'InverseGammaWithSoftplusConcentrationRate', @@ -157,6 +159,10 @@ _allowed_symbols = [ 'assign_log_moving_mean_exp', 'moving_mean_variance', 'estimator_head_distribution_regression', + 'quadrature_scheme_softmaxnormal_gauss_hermite', + 'quadrature_scheme_softmaxnormal_quantiles', + 'quadrature_scheme_lognormal_gauss_hermite', + 'quadrature_scheme_lognormal_quantiles', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index 25a9b6f5fe2ed6d218d6b44650fce17fa89c0664..288d9d8dd6f17cd6348d3d72aea4408e26913ebd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -22,9 +22,9 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import _gen_mask from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import masked_autoregressive_default_template from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import _gen_mask from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 38b3a23c2d684a6f89b7c4be4a763c649bf4de15..49451446b56d290f130c5db90c13b94974d92dc9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -28,8 +28,19 @@ from tensorflow.python.ops.distributions.bijector_test_util import assert_biject from tensorflow.python.platform import test -class ReshapeBijectorTest(test.TestCase): - """Tests correctness of the reshape transformation.""" +class _ReshapeBijectorTest(object): + """Base class for testing the reshape transformation. + + Methods defined in this class call a method self.build_shapes() that + is implemented by subclasses defined below, returning respectively + ReshapeBijectorTestStatic: static shapes, + ReshapeBijectorTestDynamic: shape placeholders of known ndims, and + ReshapeBijectorTestDynamicNdims: shape placeholders of unspecified ndims, + so that each test in this base class is automatically run over all + three cases. The subclasses also implement assertRaisesError to test + for either Python exceptions (in the case of static shapes) or + TensorFlow op errors (dynamic shapes). + """ def setUp(self): self._rng = np.random.RandomState(42) @@ -40,9 +51,10 @@ class ReshapeBijectorTest(test.TestCase): expected_y = np.reshape(expected_x, [4, 6]) with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,]) bijector = Reshape( - event_shape_out=[6,], - event_shape_in=[3, 2], + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) (x_, y_, @@ -52,66 +64,23 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.forward_log_det_jacobian(expected_x), bijector.inverse_log_det_jacobian(expected_y), - )) + ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) - def testEventShapeDynamicNdims(self): - """Check forward/inverse shape methods with dynamic ndims.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) - - bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, validate_args=True) - - # using the _tensor methods, we should always get a fully-specified - # result since these are evaluated at graph runtime. - with self.test_session() as sess: - (shape_out_, - shape_in_) = sess.run(( - bijector.forward_event_shape_tensor(shape_in), - bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeDynamic(self): - """Check shape methods with static ndims but dynamic shape.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_partial = tensor_shape.TensorShape([None,]) - shape_in_ph = array_ops.placeholder( - shape=[1,], dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_partial = tensor_shape.TensorShape([None, None]) - shape_out_ph = array_ops.placeholder( - shape=[2,], dtype=dtypes.int32) + def testEventShapeTensor(self): + """Test event_shape_tensor methods when even ndims may be dynamic.""" + shape_in_static = [2, 3] + shape_out_static = [6,] + shape_in, shape_out, feed_dict = self.build_shapes(shape_in_static, + shape_out_static) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, - validate_args=True) - - # if event shapes are not statically available, should - # return partially-specified TensorShapes. - self.assertAllEqual( - bijector.forward_event_shape(shape_in).as_list(), - shape_out_partial.as_list()) - self.assertAllEqual( - bijector.inverse_event_shape(shape_out).as_list(), - shape_in_partial.as_list()) + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) # using the _tensor methods, we should always get a fully-specified # result since these are evaluated at graph runtime. @@ -120,42 +89,9 @@ class ReshapeBijectorTest(test.TestCase): shape_in_) = sess.run(( bijector.forward_event_shape_tensor(shape_in), bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeStatic(self): - """Check shape methods when shape is statically known.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_out = tensor_shape.TensorShape([2, 3]) - - bijector_static = Reshape( - event_shape_out=shape_out, - event_shape_in=shape_in, - validate_args=True) - - # test that forward_ and inverse_event_shape do sensible things - # when shapes are statically known. - self.assertEqual( - bijector_static.forward_event_shape(shape_in), - shape_out) - self.assertEqual( - bijector_static.inverse_event_shape(shape_out), - shape_in) - - with self.test_session() as sess: - (shape_out_static_, - shape_in_static_, - ) = sess.run(( - bijector_static.forward_event_shape_tensor(shape_in), - bijector_static.inverse_event_shape_tensor(shape_out), - )) - self.assertAllEqual(shape_out, shape_out_static_) - self.assertAllEqual(shape_in, shape_in_static_) + ), feed_dict=feed_dict) + self.assertAllEqual(shape_out_static, shape_out_) + self.assertAllEqual(shape_in_static, shape_in_) def testScalarReshape(self): """Test reshaping to and from a scalar shape ().""" @@ -166,11 +102,11 @@ class ReshapeBijectorTest(test.TestCase): expected_x_scalar = np.random.randn(1,) expected_y_scalar = expected_x_scalar[0] + shape_in, shape_out, feed_dict = self.build_shapes([], [1,]) with self.test_session() as sess: bijector = Reshape( - event_shape_out=[], - event_shape_in=[1,], validate_args=True) - + event_shape_out=shape_in, + event_shape_in=shape_out, validate_args=True) (x_, y_, x_scalar_, @@ -180,53 +116,178 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.inverse(expected_y_scalar), bijector.forward(expected_x_scalar), - )) + ), feed_dict=feed_dict) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) - def testRaisesOpError(self): - x1 = np.random.randn(4, 2, 3) - x2 = np.random.randn(4, 3, 2) - x3 = np.random.randn(4, 5, 1, 1) + def testMultipleUnspecifiedDimensionsOpError(self): with self.test_session() as sess: - shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) - shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,]) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) - with self.assertRaisesOpError( + with self.assertRaisesError( + "elements must have at most one `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testInvalidDimensionsOpError(self): + + with self.test_session() as sess: + + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "elements must be either positive integers or `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testValidButNonMatchingInputOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # Here we pass in a tensor (x) whose shape is compatible with + # the output shape, so tf.reshape will throw no error, but + # doesn't match the expected input shape. + with self.assertRaisesError( "Input `event_shape` does not match `event_shape_in`."): - sess.run(bijector.forward(x2), - feed_dict={shape_out_ph: [1, 6, 1], - shape_in_ph: [2, 3]}) + sess.run(bijector.forward(x), + feed_dict=feed_dict) - with self.assertRaisesOpError( - "event_shape_out entries must be positive."): - sess.run(bijector.forward(x1), - feed_dict={shape_out_ph: [-1, -1, 6], - shape_in_ph: [2, 3]}) + def testValidButNonMatchingInputPartiallySpecifiedOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x), + feed_dict=feed_dict) + + def testInputOutputMismatchOpError(self): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 1, 1, 5) + + with self.test_session() as sess: + shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3], + [1, 1, 5]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) # test that *all* methods check basic assertions - fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): + with self.assertRaisesError( + "Input to reshape is a tensor with"): sess.run(bijector.forward(x1), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse_log_det_jacobian(x3), - feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.forward_log_det_jacobian(x1), - feed_dict=fd_mismatched) + with self.assertRaisesError( + "Input to reshape is a tensor with"): + sess.run(bijector.inverse(x2), feed_dict=fd_mismatched) + + def testOneShapePartiallySpecified(self): + expected_x = np.random.randn(4, 6) + expected_y = np.reshape(expected_x, [4, 2, 3]) + + with self.test_session() as sess: + # one of input/output shapes is partially specified + shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testBothShapesPartiallySpecified(self): + expected_x = np.random.randn(4, 2, 3) + expected_y = np.reshape(expected_x, [4, 3, 2]) + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testDefaultVectorShape(self): + expected_x = np.random.randn(4, 4) + expected_y = np.reshape(expected_x, [4, 2, 2]) + with self.test_session() as sess: + _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2]) + bijector = Reshape(shape_out, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def build_shapes(self, *args, **kwargs): + raise NotImplementedError("Subclass failed to implement `build_shapes`.") + + +class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_static = shape_in + shape_out_static = shape_out + feed_dict = {} + return shape_in_static, shape_out_static, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesRegexp(Exception, msg) + + def testEventShape(self): + shape_in_static = tensor_shape.TensorShape([2, 3]) + shape_out_static = tensor_shape.TensorShape([6,]) + bijector = Reshape( + event_shape_out=shape_out_static, + event_shape_in=shape_in_static, validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector.forward_event_shape(shape_in_static), + shape_out_static) + self.assertEqual( + bijector.inverse_event_shape(shape_out_static), + shape_in_static) def testBijectiveAndFinite(self): x = np.random.randn(4, 2, 3) @@ -238,5 +299,32 @@ class ReshapeBijectorTest(test.TestCase): validate_args=True) assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=(len(shape_in),), + dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=(len(shape_out),), + dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + +class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py index 7f7697357ce7c77b2a50b87271d4ba7b49cbe05e..73747db31c86b67eaad5aeab7d5e80191e12b333 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py @@ -41,6 +41,7 @@ def try_import(name): # pylint: disable=invalid-name tf_logging.warning("Could not import %s: %s" % (name, str(e))) return module + stats = try_import("scipy.stats") @@ -62,9 +63,9 @@ class CauchyTest(test.TestCase): self.assertAllEqual(expected, scale_shape.eval()) loc = array_ops.zeros(loc_shape) scale = array_ops.ones(scale_shape) - self.assertAllEqual( - expected, - array_ops.shape(cauchy_lib.Cauchy(loc, scale).sample()).eval()) + self.assertAllEqual(expected, + array_ops.shape( + cauchy_lib.Cauchy(loc, scale).sample()).eval()) def _testParamStaticShapes(self, sample_shape, expected): param_shapes = cauchy_lib.Cauchy.param_static_shapes(sample_shape) @@ -92,8 +93,7 @@ class CauchyTest(test.TestCase): cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) log_pdf = cauchy.log_prob(x) - self.assertAllEqual(cauchy.batch_shape_tensor().eval(), - log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape) self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.eval().shape) self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) @@ -115,16 +115,15 @@ class CauchyTest(test.TestCase): with self.test_session(): batch_size = 6 loc = constant_op.constant([[3.0, -3.0]] * batch_size) - scale = constant_op.constant([[np.sqrt(10.0), np.sqrt(15.0)]] * - batch_size) + scale = constant_op.constant( + [[np.sqrt(10.0), np.sqrt(15.0)]] * batch_size) x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) log_pdf = cauchy.log_prob(x) log_pdf_values = log_pdf.eval() self.assertEqual(log_pdf.shape, (6, 2)) - self.assertAllEqual(cauchy.batch_shape_tensor().eval(), - log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape) self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.eval().shape) self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) @@ -248,8 +247,7 @@ class CauchyTest(test.TestCase): cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) entropy = cauchy.entropy() - self.assertAllEqual(cauchy.batch_shape_tensor().eval(), - entropy.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), entropy.shape) self.assertAllEqual(cauchy.batch_shape_tensor().eval(), entropy.eval().shape) self.assertAllEqual(cauchy.batch_shape, entropy.shape) @@ -257,7 +255,7 @@ class CauchyTest(test.TestCase): if not stats: return - expected_entropy = stats.cauchy(loc, scale).entropy() + expected_entropy = stats.cauchy(loc, scale[0]).entropy().reshape((1, 3)) self.assertAllClose(expected_entropy, entropy.eval()) def testCauchyMode(self): @@ -368,8 +366,8 @@ class CauchyTest(test.TestCase): self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, sample_values.shape) - expected_shape = (tensor_shape.TensorShape( - [n.eval()]).concatenate(cauchy.batch_shape)) + expected_shape = ( + tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape)) self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, sample_values.shape) @@ -385,18 +383,18 @@ class CauchyTest(test.TestCase): samples = cauchy.sample(n) sample_values = samples.eval() self.assertEqual(samples.shape, (100000, batch_size, 2)) - self.assertAllClose(np.median(sample_values[:, 0, 0]), - loc_v[0], atol=1e-1) - self.assertAllClose(np.median(sample_values[:, 0, 1]), - loc_v[1], atol=1e-1) + self.assertAllClose( + np.median(sample_values[:, 0, 0]), loc_v[0], atol=1e-1) + self.assertAllClose( + np.median(sample_values[:, 0, 1]), loc_v[1], atol=1e-1) expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval())) self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, sample_values.shape) - expected_shape = (tensor_shape.TensorShape( - [n.eval()]).concatenate(cauchy.batch_shape)) + expected_shape = ( + tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape)) self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, sample_values.shape) @@ -428,9 +426,12 @@ class CauchyTest(test.TestCase): self.assertEqual(cauchy.event_shape, ()) self.assertAllEqual(cauchy.event_shape_tensor().eval(), []) self.assertAllEqual( - sess.run(cauchy.batch_shape_tensor(), - feed_dict={loc: 5.0, - scale: [1.0, 2.0]}), [2]) + sess.run( + cauchy.batch_shape_tensor(), + feed_dict={ + loc: 5.0, + scale: [1.0, 2.0] + }), [2]) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 2d74aa1f320149d0f7ef9e9c52b8c7053c2f74d7..a255d4fc890e67180532e342332a8e3f63a869cd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -395,5 +395,110 @@ class MixtureStddevTest(test.TestCase): self.assertAllClose(actual_devs, expected_devs) +class _PadTest(object): + + def testNegAxisCorrectness(self): + x_ = np.float32([[1., 2, 3], + [4, 5, 6]]) + value_ = np.float32(0.25) + count_ = np.int32(2) + with self.test_session() as sess: + x = array_ops.placeholder_with_default( + x_, shape=x_.shape if self.is_static_shape else None) + value = (constant_op.constant(value_) if self.is_static_shape + else array_ops.placeholder_with_default(value_, shape=None)) + count = (constant_op.constant(count_) if self.is_static_shape + else array_ops.placeholder_with_default(count_, shape=None)) + + x0_front = distribution_util.pad( + x, axis=-2, value=value, count=count, front=True) + x0_back = distribution_util.pad( + x, axis=-2, count=count, back=True) + x0_both = distribution_util.pad( + x, axis=-2, value=value, front=True, back=True) + + if self.is_static_shape: + self.assertAllEqual([4, 3], x0_front.shape) + self.assertAllEqual([4, 3], x0_back.shape) + self.assertAllEqual([4, 3], x0_both.shape) + + [x0_front_, x0_back_, x0_both_] = sess.run([ + x0_front, x0_back, x0_both]) + + self.assertAllClose( + np.float32([[value_]*3, + [value_]*3, + [1, 2, 3], + [4, 5, 6]]), + x0_front_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[1, 2, 3], + [4, 5, 6], + [0.]*3, + [0.]*3]), + x0_back_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[value_]*3, + [1, 2, 3], + [4, 5, 6], + [value_]*3]), + x0_both_, atol=0., rtol=1e-6) + + def testPosAxisCorrectness(self): + x_ = np.float32([[1., 2, 3], + [4, 5, 6]]) + value_ = np.float32(0.25) + count_ = np.int32(2) + with self.test_session() as sess: + x = array_ops.placeholder_with_default( + x_, shape=x_.shape if self.is_static_shape else None) + value = (constant_op.constant(value_) if self.is_static_shape + else array_ops.placeholder_with_default(value_, shape=None)) + count = (constant_op.constant(count_) if self.is_static_shape + else array_ops.placeholder_with_default(count_, shape=None)) + + x1_front = distribution_util.pad( + x, axis=1, value=value, count=count, front=True) + x1_back = distribution_util.pad( + x, axis=1, count=count, back=True) + x1_both = distribution_util.pad( + x, axis=1, value=value, front=True, back=True) + + if self.is_static_shape: + self.assertAllEqual([2, 5], x1_front.shape) + self.assertAllEqual([2, 5], x1_back.shape) + self.assertAllEqual([2, 5], x1_both.shape) + + [x1_front_, x1_back_, x1_both_] = sess.run([ + x1_front, x1_back, x1_both]) + + self.assertAllClose( + np.float32([[value_]*2 + [1, 2, 3], + [value_]*2 + [4, 5, 6]]), + x1_front_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[1, 2, 3] + [0.]*2, + [4, 5, 6] + [0.]*2]), + x1_back_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[value_, 1, 2, 3, value_], + [value_, 4, 5, 6, value_]]), + x1_both_, atol=0., rtol=1e-6) + + +class PadStaticTest(_PadTest, test.TestCase): + + @property + def is_static_shape(self): + return True + + +class PadDynamicTest(_PadTest, test.TestCase): + + @property + def is_static_shape(self): + return False + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e75660083dc2edd1759a3a54e221d9e8a268c3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py @@ -0,0 +1,320 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import numpy as np + +from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +stats = try_import("scipy.stats") + + +class HalfNormalTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(123) + + def assertAllFinite(self, tensor): + is_finite = np.isfinite(tensor.eval()) + all_true = np.ones_like(is_finite, dtype=np.bool) + self.assertAllEqual(all_true, is_finite) + + def _testParamShapes(self, sample_shape, expected): + with self.test_session(): + param_shapes = hn_lib.HalfNormal.param_shapes(sample_shape) + scale_shape = param_shapes["scale"] + self.assertAllEqual(expected, scale_shape.eval()) + scale = array_ops.ones(scale_shape) + self.assertAllEqual( + expected, + array_ops.shape(hn_lib.HalfNormal(scale).sample()).eval()) + + def _testParamStaticShapes(self, sample_shape, expected): + param_shapes = hn_lib.HalfNormal.param_static_shapes(sample_shape) + scale_shape = param_shapes["scale"] + self.assertEqual(expected, scale_shape) + + def _testBatchShapes(self, dist, tensor): + self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.shape) + self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.eval().shape) + self.assertAllEqual(dist.batch_shape, tensor.shape) + self.assertAllEqual(dist.batch_shape, tensor.eval().shape) + + def testParamShapes(self): + sample_shape = [10, 3, 4] + self._testParamShapes(sample_shape, sample_shape) + self._testParamShapes(constant_op.constant(sample_shape), sample_shape) + + def testParamStaticShapes(self): + sample_shape = [10, 3, 4] + self._testParamStaticShapes(sample_shape, sample_shape) + self._testParamStaticShapes( + tensor_shape.TensorShape(sample_shape), sample_shape) + + def testHalfNormalLogPDF(self): + with self.test_session(): + batch_size = 6 + scale = constant_op.constant([3.0] * batch_size) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) + halfnorm = hn_lib.HalfNormal(scale=scale) + + log_pdf = halfnorm.log_prob(x) + self._testBatchShapes(halfnorm, log_pdf) + + pdf = halfnorm.prob(x) + self._testBatchShapes(halfnorm, pdf) + + if not stats: + return + expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testHalfNormalLogPDFMultidimensional(self): + with self.test_session(): + batch_size = 6 + scale = constant_op.constant([[3.0, 1.0]] * batch_size) + x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T + halfnorm = hn_lib.HalfNormal(scale=scale) + + log_pdf = halfnorm.log_prob(x) + self._testBatchShapes(halfnorm, log_pdf) + + pdf = halfnorm.prob(x) + self._testBatchShapes(halfnorm, pdf) + + if not stats: + return + expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testHalfNormalCDF(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + halfnorm = hn_lib.HalfNormal(scale=scale) + + cdf = halfnorm.cdf(x) + self._testBatchShapes(halfnorm, cdf) + + log_cdf = halfnorm.log_cdf(x) + self._testBatchShapes(halfnorm, log_cdf) + + if not stats: + return + expected_logcdf = stats.halfnorm(scale=scale).logcdf(x) + self.assertAllClose(expected_logcdf, log_cdf.eval(), atol=0) + self.assertAllClose(np.exp(expected_logcdf), cdf.eval(), atol=0) + + def testHalfNormalSurvivalFunction(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sf = halfnorm.survival_function(x) + self._testBatchShapes(halfnorm, sf) + + log_sf = halfnorm.log_survival_function(x) + self._testBatchShapes(halfnorm, log_sf) + + if not stats: + return + expected_logsf = stats.halfnorm(scale=scale).logsf(x) + self.assertAllClose(expected_logsf, log_sf.eval(), atol=0) + self.assertAllClose(np.exp(expected_logsf), sf.eval(), atol=0) + + def testHalfNormalQuantile(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + p = np.linspace(0., 1.0, batch_size).astype(np.float64) + + halfnorm = hn_lib.HalfNormal(scale=scale) + x = halfnorm.quantile(p) + self._testBatchShapes(halfnorm, x) + + if not stats: + return + expected_x = stats.halfnorm(scale=scale).ppf(p) + self.assertAllClose(expected_x, x.eval(), atol=0) + + def testFiniteGradients(self): + for dtype in [np.float32, np.float64]: + g = ops.Graph() + with g.as_default(): + scale = variables.Variable(dtype(3.0)) + dist = hn_lib.HalfNormal(scale=scale) + x = np.array([0.01, 0.1, 1., 5., 10.]).astype(dtype) + for func in [ + dist.cdf, dist.log_cdf, dist.survival_function, + dist.log_prob, dist.prob, dist.log_survival_function, + ]: + print(func.__name__) + value = func(x) + grads = gradients_impl.gradients(value, [scale]) + with self.test_session(graph=g): + variables.global_variables_initializer().run() + self.assertAllFinite(value) + self.assertAllFinite(grads[0]) + + def testHalfNormalEntropy(self): + with self.test_session(): + scale = np.array([[1.0, 2.0, 3.0]]) + halfnorm = hn_lib.HalfNormal(scale=scale) + + # See https://en.wikipedia.org/wiki/Half-normal_distribution for the + # entropy formula used here. + expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5 + + entropy = halfnorm.entropy() + self._testBatchShapes(halfnorm, entropy) + self.assertAllClose(expected_entropy, entropy.eval()) + + def testHalfNormalMeanAndMode(self): + with self.test_session(): + scale = np.array([11., 12., 13.]) + + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_mean = scale * np.sqrt(2.0) / np.sqrt(np.pi) + + self.assertAllEqual((3,), halfnorm.mean().eval().shape) + self.assertAllEqual(expected_mean, halfnorm.mean().eval()) + + self.assertAllEqual((3,), halfnorm.mode().eval().shape) + self.assertAllEqual([0., 0., 0.], halfnorm.mode().eval()) + + def testHalfNormalVariance(self): + with self.test_session(): + scale = np.array([7., 7., 7.]) + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi) + + self.assertAllEqual((3,), halfnorm.variance().eval().shape) + self.assertAllEqual(expected_variance, halfnorm.variance().eval()) + + def testHalfNormalStandardDeviation(self): + with self.test_session(): + scale = np.array([7., 7., 7.]) + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi) + + self.assertAllEqual((3,), halfnorm.stddev().shape) + self.assertAllEqual(np.sqrt(expected_variance), halfnorm.stddev().eval()) + + def testHalfNormalSample(self): + with self.test_session(): + scale = constant_op.constant(3.0) + n = constant_op.constant(100000) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sample = halfnorm.sample(n) + + self.assertEqual(sample.eval().shape, (100000,)) + self.assertAllClose(sample.eval().mean(), + 3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, sample.shape) + self.assertAllEqual(expected_shape, sample.eval().shape) + + expected_shape_static = (tensor_shape.TensorShape( + [n.eval()]).concatenate(halfnorm.batch_shape)) + self.assertAllEqual(expected_shape_static, sample.shape) + self.assertAllEqual(expected_shape_static, sample.eval().shape) + + def testHalfNormalSampleMultiDimensional(self): + with self.test_session(): + batch_size = 2 + scale = constant_op.constant([[2.0, 3.0]] * batch_size) + n = constant_op.constant(100000) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sample = halfnorm.sample(n) + self.assertEqual(sample.shape, (100000, batch_size, 2)) + self.assertAllClose(sample.eval()[:, 0, 0].mean(), + 2.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + self.assertAllClose(sample.eval()[:, 0, 1].mean(), + 3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, sample.shape) + self.assertAllEqual(expected_shape, sample.eval().shape) + + expected_shape_static = (tensor_shape.TensorShape( + [n.eval()]).concatenate(halfnorm.batch_shape)) + self.assertAllEqual(expected_shape_static, sample.shape) + self.assertAllEqual(expected_shape_static, sample.eval().shape) + + def testNegativeSigmaFails(self): + with self.test_session(): + halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G") + with self.assertRaisesOpError("Condition x > 0 did not hold"): + halfnorm.mean().eval() + + def testHalfNormalShape(self): + with self.test_session(): + scale = constant_op.constant([6.0] * 5) + halfnorm = hn_lib.HalfNormal(scale=scale) + + self.assertEqual(halfnorm.batch_shape_tensor().eval(), [5]) + self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(halfnorm.event_shape_tensor().eval(), []) + self.assertEqual(halfnorm.event_shape, tensor_shape.TensorShape([])) + + def testHalfNormalShapeWithPlaceholders(self): + scale = array_ops.placeholder(dtype=dtypes.float32) + halfnorm = hn_lib.HalfNormal(scale=scale) + + with self.test_session() as sess: + # get_batch_shape should return an "" tensor. + self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape(None)) + self.assertEqual(halfnorm.event_shape, ()) + self.assertAllEqual(halfnorm.event_shape_tensor().eval(), []) + self.assertAllEqual( + sess.run(halfnorm.batch_shape_tensor(), + feed_dict={scale: [1.0, 2.0]}), [2]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py index ece6bc077d9e21502fdfd01300a9d3e9f2c9c380..ff6092fc260660b512e8123823c63e98a023af6d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -45,6 +45,17 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, self.assertEqual([4, 5], x.shape) self.assertEqual([4, 5], log_prob_x.shape) + def testSampleAndLogProbBatch(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[[0.3, 0.7]]), + components_distribution=normal_lib.Normal( + loc=[[-1., 1]], scale=[[0.1, 0.5]])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 1], x.shape) + self.assertEqual([4, 5, 1], log_prob_x.shape) + def testSampleAndLogProbShapesBroadcastMix(self): mix_probs = np.float32([.3, .7]) bern_probs = np.float32([[.4, .6], [.25, .75]]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py index 3c0147b8cf6e1b6a2791e85c0c0997992445fa7e..1035cb00f76d95c7c52c3e812e8bb2868d34b890 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py @@ -18,37 +18,40 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.distributions.python.ops import poisson_lognormal from tensorflow.contrib.distributions.python.ops import test_util -from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class PoissonLogNormalQuadratureCompoundTest( - test_util.DiscreteScalarDistributionTestHelpers, test.TestCase): +class _PoissonLogNormalQuadratureCompoundTest( + test_util.DiscreteScalarDistributionTestHelpers): """Tests the PoissonLogNormalQuadratureCompoundTest distribution.""" def testSampleProbConsistent(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=-2., - scale=1.1, - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + -2., + shape=[] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1.1, + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1) + sess.run, pln, batch_size=1, rtol=0.1) def testMeanVariance(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=0., - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + 0., + shape=[] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.02) @@ -56,21 +59,27 @@ class PoissonLogNormalQuadratureCompoundTest( def testSampleProbConsistentBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[0., -0.5], - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [0., -0.5], + shape=[2] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1, atol=0.01) + sess.run, pln, batch_size=2, rtol=0.1, atol=0.01) def testMeanVarianceBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[0., -0.5], - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [0., -0.5], + shape=[2] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.1, atol=0.01) @@ -78,38 +87,46 @@ class PoissonLogNormalQuadratureCompoundTest( def testSampleProbConsistentBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[[0.], [-0.5]], - scale=[[1., 0.9]], - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [[0.], [-0.5]], + shape=[2, 1] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + [[1., 0.9]], + shape=[1, 2] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1, atol=0.08) + sess.run, pln, batch_size=4, rtol=0.1, atol=0.08) def testMeanVarianceBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[[0.], [-0.5]], - scale=[[1., 0.9]], - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [[0.], [-0.5]], + shape=[2, 1] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + [[1., 0.9]], + shape=[1, 2] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.1, atol=0.01) - def testSampleProbConsistentDynamicQuadrature(self): - with self.test_session() as sess: - qgrid = array_ops.placeholder(dtype=dtypes.float32) - qprobs = array_ops.placeholder(dtype=dtypes.float32) - g, p = np.polynomial.hermite.hermgauss(deg=10) - pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=-2., - scale=1.1, - quadrature_grid_and_probs=(g, p), - validate_args=True) - self.run_test_sample_consistent_log_prob( - lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}), - pln, rtol=0.1) + +class PoissonLogNormalQuadratureCompoundStaticShapeTest( + _PoissonLogNormalQuadratureCompoundTest, test.TestCase): + + @property + def static_shape(self): + return True + + +class PoissonLogNormalQuadratureCompoundDynamicShapeTest( + _PoissonLogNormalQuadratureCompoundTest, test.TestCase): + + @property + def static_shape(self): + return False if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py index 595d9f5df755d7defa63d385039bafe4f87aa6ec..4186cf129dbf31724c84133734da3f226817c71a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py @@ -23,11 +23,244 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import sample_stats from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops_test_util from tensorflow.python.platform import test rng = np.random.RandomState(0) +class _AutoCorrelationTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError("Subclass failed to implement `use_static_shape`") + + @property + def dtype(self): + raise NotImplementedError("Subclass failed to implement `dtype`.") + + def test_constant_sequence_axis_0_max_lags_none_center_false(self): + x_ = np.array([[0., 0., 0.], + [1., 1., 1.]]).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + input=x_, + shape=x_.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session() as sess: + # Setting normalize = True means we divide by zero. + auto_corr = sample_stats.auto_correlation( + x_ph, axis=1, center=False, normalize=False) + if self.use_static_shape: + self.assertEqual((2, 3), auto_corr.shape) + auto_corr_ = sess.run(auto_corr) + self.assertAllClose( + [[0., 0., 0.], + [1., 1., 1.]], auto_corr_) + + def test_constant_sequence_axis_0_max_lags_none_center_true(self): + x_ = np.array([[0., 0., 0.], + [1., 1., 1.]]).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + input=x_, + shape=x_.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session() as sess: + # Setting normalize = True means we divide by zero. + auto_corr = sample_stats.auto_correlation( + x_ph, axis=1, normalize=False, center=True) + if self.use_static_shape: + self.assertEqual((2, 3), auto_corr.shape) + auto_corr_ = sess.run(auto_corr) + self.assertAllClose( + [[0., 0., 0.], + [0., 0., 0.]], auto_corr_) + + def check_results_versus_brute_force( + self, x, axis, max_lags, center, normalize): + """Compute auto-correlation by brute force, then compare to tf result.""" + # Brute for auto-corr -- avoiding fft and transpositions. + axis_len = x.shape[axis] + if max_lags is None: + max_lags = axis_len - 1 + else: + max_lags = min(axis_len - 1, max_lags) + auto_corr_at_lag = [] + if center: + x -= x.mean(axis=axis, keepdims=True) + for m in range(max_lags + 1): + auto_corr_at_lag.append(( + np.take(x, indices=range(0, axis_len - m), axis=axis) * + np.conj(np.take(x, indices=range(m, axis_len), axis=axis)) + ).mean(axis=axis, keepdims=True)) + rxx = np.concatenate(auto_corr_at_lag, axis=axis) + if normalize: + rxx /= np.take(rxx, [0], axis=axis) + + x_ph = array_ops.placeholder_with_default( + x, shape=x.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + auto_corr = sample_stats.auto_correlation( + x_ph, axis=axis, max_lags=max_lags, center=center, + normalize=normalize) + if self.use_static_shape: + output_shape = list(x.shape) + output_shape[axis] = max_lags + 1 + self.assertAllEqual(output_shape, auto_corr.shape) + self.assertAllClose(rxx, auto_corr.eval(), rtol=1e-5, atol=1e-5) + + def test_axis_n1_center_false_max_lags_none(self): + x = rng.randn(2, 3, 4).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(2, 3, 4).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-1, max_lags=None, center=False, normalize=False) + + def test_axis_n2_center_false_max_lags_none(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-2, max_lags=None, center=False, normalize=False) + + def test_axis_n1_center_false_max_lags_none_normalize_true(self): + x = rng.randn(2, 3, 4).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(2, 3, 4).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-1, max_lags=None, center=False, normalize=True) + + def test_axis_n2_center_false_max_lags_none_normalize_true(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-2, max_lags=None, center=False, normalize=True) + + def test_axis_0_center_true_max_lags_none(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=0, max_lags=None, center=True, normalize=False) + + def test_axis_2_center_true_max_lags_1(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=2, max_lags=1, center=True, normalize=False) + + def test_axis_2_center_true_max_lags_100(self): + # There are less than 100 elements in axis 2, so expect we get back an array + # the same size as x, despite having asked for 100 lags. + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=2, max_lags=100, center=True, normalize=False) + + def test_long_orthonormal_sequence_has_corr_length_0(self): + l = 10000 + x = rng.randn(l).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(l,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=l // 2, center=True, normalize=False) + if self.use_static_shape: + self.assertAllEqual((l // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + # OSS CPU FFT has some accuracy issues is not the most accurate. + # So this tolerance is a bit bad. + self.assertAllClose(1., rxx_[0], rtol=0.05) + # The maximal error in the rest of the sequence is not great. + self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) + # The mean error in the rest is ok, actually 0.008 when I tested it. + self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) + + def test_step_function_sequence(self): + # x jumps to new random value every 10 steps. So correlation length = 10. + x = (rng.randint(-10, 10, size=(1000, 1)) + * np.ones((1, 10))).ravel().astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(1000 * 10,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=1000 * 10 // 2, center=True, normalize=False) + if self.use_static_shape: + self.assertAllEqual((1000 * 10 // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + rxx_ /= rxx_[0] + # Expect positive correlation for the first 10 lags, then significantly + # smaller negative. + self.assertGreater(rxx_[:10].min(), 0) + self.assertGreater(rxx_[9], 5 * rxx_[10:20].mean()) + # RXX should be decreasing for the first 10 lags. + diff = np.diff(rxx_) + self.assertLess(diff[:10].max(), 0) + + def test_normalization(self): + l = 10000 + x = 3 * rng.randn(l).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(l,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=l // 2, center=True, normalize=True) + if self.use_static_shape: + self.assertAllEqual((l // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + # Note that RXX[0] = 1, despite the fact that E[X^2] = 9, and this is + # due to normalize=True. + # OSS CPU FFT has some accuracy issues is not the most accurate. + # So this tolerance is a bit bad. + self.assertAllClose(1., rxx_[0], rtol=0.05) + # The maximal error in the rest of the sequence is not great. + self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) + # The mean error in the rest is ok, actually 0.008 when I tested it. + self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) + + +class AutoCorrelationTestStaticShapeFloat32(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.float32 + + @property + def use_static_shape(self): + return True + + +class AutoCorrelationTestStaticShapeComplex64(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.complex64 + + @property + def use_static_shape(self): + return True + + +class AutoCorrelationTestDynamicShapeFloat32(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.float32 + + @property + def use_static_shape(self): + return False + + class PercentileTestWithLowerInterpolation(test.TestCase): _interpolation = "lower" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 103d8e186221e879d1734a097114708429f725bd..cbaf74d3f66253ae5727e1ba579e2d49235b748e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -200,6 +200,27 @@ class TransformedDistributionTest(test.TestCase): self.assertAllEqual([2], multi_logit_normal.event_shape) self.assertAllEqual([2], multi_logit_normal.event_shape_tensor().eval()) + def testCastLogDetJacobian(self): + """Test log_prob when Jacobian and log_prob dtypes do not match.""" + + with self.test_session(): + # Create an identity bijector whose jacobians have dtype int32 + int_identity = bs.Inline( + forward_fn=array_ops.identity, + inverse_fn=array_ops.identity, + inverse_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + forward_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + is_constant_jacobian=True) + normal = self._cls()( + distribution=ds.Normal(loc=0., scale=1.), + bijector=int_identity, + validate_args=True) + + y = normal.sample() + normal.log_prob(y).eval() + normal.prob(y).eval() + normal.entropy().eval() + def testEntropy(self): with self.test_session(): shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index de4a221f7badca8267a81d612a57137c676ff052..d292b04665e34196670ee4f1c1655f805e04e06a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -21,9 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops import test_util -from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops +from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vdm_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib @@ -37,7 +35,7 @@ class VectorDiffeomixtureTest( def testSampleProbConsistentBroadcastMixNoBatch(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], mix_scale=[1.], distribution=normal_lib.Normal(0., 1.), @@ -54,18 +52,19 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.015) def testSampleProbConsistentBroadcastMixNonStandardBase(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], mix_scale=[1.], distribution=normal_lib.Normal(1., 1.5), @@ -82,18 +81,19 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=1., rtol=0.006) + sess.run, vdm, radius=2., center=1., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=3., rtol=0.009) + sess.run, vdm, radius=4., center=3., rtol=0.01) def testSampleProbConsistentBroadcastMixBatch(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], mix_scale=[1.], distribution=normal_lib.Normal(0., 1.), @@ -113,18 +113,19 @@ class VectorDiffeomixtureTest( ]), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.01) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.01) def testMeanCovarianceNoBatch(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], mix_scale=[10.], distribution=normal_lib.Normal(0., 1.), @@ -141,14 +142,15 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess.run, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.08) def testMeanCovarianceNoBatchUncenteredNonStandardBase(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], mix_scale=[10.], distribution=normal_lib.Normal(-1., 1.5), @@ -165,6 +167,7 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) @@ -172,7 +175,7 @@ class VectorDiffeomixtureTest( def testMeanCovarianceBatch(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], mix_scale=[10.], distribution=normal_lib.Normal(0., 1.), @@ -192,18 +195,16 @@ class VectorDiffeomixtureTest( ]), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess.run, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.07) - def testSampleProbConsistentDynamicQuadrature(self): + def testSampleProbConsistentQuadrature(self): with self.test_session() as sess: - qgrid = array_ops.placeholder(dtype=dtypes.float32) - qprobs = array_ops.placeholder(dtype=dtypes.float32) - g, p = np.polynomial.hermite.hermgauss(deg=8) dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( - mix_loc=[[0.], [1.]], + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=[0.], mix_scale=[1.], distribution=normal_lib.Normal(0., 1.), loc=[ @@ -219,15 +220,14 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], - quadrature_grid_and_probs=(g, p), + quadrature_size=3, validate_args=True) # Ball centered at component0's mean. - sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}) self.run_test_sample_consistent_log_prob( - sess_run_fn, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess_run_fn, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.005) # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent, # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index 6049419818e18c54209f0be95d41fcecf6627b7e..0fe9f6aa78fbe845b99d0668f075b0162ec2a9f7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -18,12 +18,117 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["AbsoluteValue"] +__all__ = [ + "AbsoluteValue", +] -remove_undocumented(__name__, _allowed_symbols) + +class AbsoluteValue(bijector.Bijector): + """Computes `Y = g(X) = Abs(X)`, element-wise. + + This non-injective bijector allows for transformations of scalar distributions + with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. + + * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse + `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. + * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse + (the set inverse is the singleton `{0}`), but "works" in conjunction with + `TransformedDistribution` to produce a left semi-continuous pdf. + * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the + wrong thing, `-y, y`. This is done for efficiency. If + `validate_args == True`, `y < 0` will raise an exception. + + + ```python + tfd = tf.contrib.distributions + + abs = tfd.bijectors.AbsoluteValue() + + abs.forward([-1., 0., 1.]) + ==> [1., 0., 1.] + + abs.inverse(1.) + ==> [-1., 1.] + + # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. + abs.inverse_log_det_jacobian(1.) + ==> [0., 0.] + + # Special case handling of 0. + abs.inverse(0.) + ==> [0., 0.] + + abs.inverse_log_det_jacobian(0.) + ==> [0., 0.] + ``` + + """ + + def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + """Instantiates the `AbsoluteValue` bijector. + + Args: + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness, in particular whether inputs to `inverse` and + `inverse_log_det_jacobian` are non-negative. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init"): + super(AbsoluteValue, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return math_ops.abs(x) + + def _inverse(self, y): + if self.validate_args: + y = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + y) + return -y, y + + def _inverse_log_det_jacobian(self, y): + # If event_ndims = 2, + # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), + # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. + batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] + zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + if self.validate_args: + zeros = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + zeros) + return zeros, zeros + + @property + def _is_injective(self): + return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py deleted file mode 100644 index b84502003ab6c0c4ffdda21eea162f441509e1fa..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""AbsoluteValue bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "AbsoluteValue", -] - - -class AbsoluteValue(bijector.Bijector): - """Computes `Y = g(X) = Abs(X)`, element-wise. - - This non-injective bijector allows for transformations of scalar distributions - with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. - - * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse - `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. - * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse - (the set inverse is the singleton `{0}`), but "works" in conjunction with - `TransformedDistribution` to produce a left semi-continuous pdf. - * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the - wrong thing, `-y, y`. This is done for efficiency. If - `validate_args == True`, `y < 0` will raise an exception. - - - ```python - abs = ds.bijectors.AbsoluteValue() - - abs.forward([-1., 0., 1.]) - ==> [1., 0., 1.] - - abs.inverse(1.) - ==> [-1., 1.] - - # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. - abs.inverse_log_det_jacobian(1.) - ==> [0., 0.] - - # Special case handling of 0. - abs.inverse(0.) - ==> [0., 0.] - - abs.inverse_log_det_jacobian(0.) - ==> [0., 0.] - ``` - - """ - - def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): - """Instantiates the `AbsoluteValue` bijector. - - Args: - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness, in particular whether inputs to `inverse` and - `inverse_log_det_jacobian` are non-negative. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. - """ - self._graph_parents = [] - self._name = name - - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - - with self._name_scope("init"): - super(AbsoluteValue, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - return math_ops.abs(x) - - def _inverse(self, y): - if self.validate_args: - y = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - y) - return -y, y - - def _inverse_log_det_jacobian(self, y): - # If event_ndims = 2, - # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), - # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. - batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] - zeros = array_ops.zeros(batch_shape, dtype=y.dtype) - if self.validate_args: - zeros = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - zeros) - return zeros, zeros - - @property - def _is_injective(self): - return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 940cceff04e77cfc2f7caae5a798d135f7601b95..05bb9c2f9bdf35e222c94db3491157893da64ebd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -18,12 +18,386 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib import linalg +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Affine"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Affine", +] + + +def _as_tensor(x, name): + """Convenience to convert to `Tensor` or leave as `None`.""" + return None if x is None else ops.convert_to_tensor(x, name=name) + + +class Affine(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. + + Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. + + In TF parlance, the `scale` term is logically equivalent to: + + ```python + scale = ( + scale_identity_multiplier * tf.diag(tf.ones(d)) + + tf.diag(scale_diag) + + scale_tril + + scale_perturb_factor @ diag(scale_perturb_diag) @ + tf.transpose([scale_perturb_factor]) + ) + ``` + + The `scale` term is applied without necessarily materializing constituent + matrices, i.e., the matmul is [matrix-free]( + https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. + + Examples: + + ```python + # Y = X + b = Affine() + + # Y = X + shift + b = Affine(shift=[1., 2, 3]) + + # Y = 2 * I @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_identity_multiplier=2.) + + # Y = tf.diag(d1) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_diag=[-1., 2, 1]) # Implicitly 3x3. + + # Y = (I + v * v.T) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_perturb_factor=[[1., 0], + [0, 1], + [1, 1]]) + + # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_diag=[1., 3, 3], # Implicitly 3x3. + scale_perturb_diag=[2., 1], # Implicitly 2x2. + scale_perturb_factor=[[1., 0], + [0, 1], + [1, 1]]) + + ``` + + """ + + def __init__(self, + shift=None, + scale_identity_multiplier=None, + scale_diag=None, + scale_tril=None, + scale_perturb_factor=None, + scale_perturb_diag=None, + event_ndims=1, + validate_args=False, + name="affine"): + """Instantiates the `Affine` bijector. + + This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, + giving the forward operation: + + ```none + Y = g(X) = scale @ X + shift + ``` + + where the `scale` term is logically equivalent to: + + ```python + scale = ( + scale_identity_multiplier * tf.diag(tf.ones(d)) + + tf.diag(scale_diag) + + scale_tril + + scale_perturb_factor @ diag(scale_perturb_diag) @ + tf.transpose([scale_perturb_factor]) + ) + ``` + + If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are + specified then `scale += IdentityMatrix`. Otherwise specifying a + `scale` argument has the semantics of `scale += Expand(arg)`, i.e., + `scale_diag != None` means `scale += tf.diag(scale_diag)`. + + Args: + shift: Floating-point `Tensor`. If this is set to `None`, no shift is + applied. + scale_identity_multiplier: floating point rank 0 `Tensor` representing a + scaling done to the identity matrix. + When `scale_identity_multiplier = scale_diag = scale_tril = None` then + `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added + to `scale`. + scale_diag: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + When `None` no diagonal term is added to `scale`. + scale_tril: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k + lower triangular matrix. + When `None` no `scale_tril` term is added to `scale`. + The upper triangular elements above the diagonal are ignored. + scale_perturb_factor: Floating-point `Tensor` representing factor matrix + with last two dimensions of shape `(k, r)`. When `None`, no rank-r + update is added to `scale`. + scale_perturb_diag: Floating-point `Tensor` representing the diagonal + matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which + represents an `r x r` diagonal matrix. When `None` low rank updates will + take the form `scale_perturb_factor * scale_perturb_factor.T`. + event_ndims: Scalar `int` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `perturb_diag` is specified but not `perturb_factor`. + TypeError: if `shift` has different `dtype` from `scale` arguments. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + # Ambiguous definition of low rank update. + if scale_perturb_diag is not None and scale_perturb_factor is None: + raise ValueError("When scale_perturb_diag is specified, " + "scale_perturb_factor must be specified.") + + # Special case, only handling a scaled identity matrix. We don't know its + # dimensions, so this is special cased. + # We don't check identity_multiplier, since below we set it to 1. if all + # other scale args are None. + self._is_only_identity_multiplier = (scale_tril is None and + scale_diag is None and + scale_perturb_factor is None) + + with self._name_scope("init", values=[ + shift, scale_identity_multiplier, scale_diag, scale_tril, + scale_perturb_diag, scale_perturb_factor]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0, 1): + raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + + if event_ndims_const == 0 and not self._is_only_identity_multiplier: + raise ValueError( + "If event_ndims == 0, the only scale argument you can pass is " + "scale_identity_multiplier. All others operate on vectors.") + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + dtype = shift.dtype.base_dtype + self._shift = shift + + # When no args are specified, pretend the scale matrix is the identity + # matrix. + if (self._is_only_identity_multiplier and + scale_identity_multiplier is None): + scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) + + # self._create_scale_operator returns a LinearOperator in all cases + # except if self._is_only_identity_multiplier; in which case it + # returns a scalar Tensor. + scale = self._create_scale_operator( + identity_multiplier=scale_identity_multiplier, + diag=scale_diag, + tril=scale_tril, + perturb_diag=scale_perturb_diag, + perturb_factor=scale_perturb_factor, + shift=shift, + validate_args=validate_args) + + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + + if scale is not None and not self._is_only_identity_multiplier: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + else: + # We won't need shape inference when scale is None or when scale is a + # scalar. + batch_ndims = 0 + self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(Affine, self).__init__( + event_ndims=event_ndims, + graph_parents=( + [event_ndims] + + [self._scale] if tensor_util.is_tensor(self._scale) + else self._scale.graph_parents + + [self._shift] if self._shift is not None else []), + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + def _create_scale_operator(self, identity_multiplier, diag, tril, + perturb_diag, perturb_factor, shift, + validate_args): + """Construct `scale` from various components. + + Args: + identity_multiplier: floating point rank 0 `Tensor` representing a scaling + done to the identity matrix. + diag: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + tril: Floating-point `Tensor` representing the diagonal matrix. + `scale_tril` has shape [N1, N2, ... k], which represents a k x k lower + triangular matrix. + perturb_diag: Floating-point `Tensor` representing the diagonal matrix of + the low rank update. + perturb_factor: Floating-point `Tensor` representing factor matrix. + shift: Floating-point `Tensor` representing `shift in `scale @ X + shift`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + + Returns: + scale. In the case of scaling by a constant, scale is a + floating point `Tensor`. Otherwise, scale is a `LinearOperator`. + + Raises: + ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`. + """ + identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") + diag = _as_tensor(diag, "diag") + tril = _as_tensor(tril, "tril") + perturb_diag = _as_tensor(perturb_diag, "perturb_diag") + perturb_factor = _as_tensor(perturb_factor, "perturb_factor") + + # If possible, use the low rank update to infer the shape of + # the identity matrix, when scale represents a scaled identity matrix + # with a low rank update. + shape_hint = None + if perturb_factor is not None: + shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) + + if self._is_only_identity_multiplier: + if validate_args: + return control_flow_ops.with_dependencies( + [check_ops.assert_none_equal( + identity_multiplier, + array_ops.zeros([], identity_multiplier.dtype), + ["identity_multiplier should be non-zero."])], + identity_multiplier) + return identity_multiplier + + scale = distribution_util.make_tril_scale( + loc=shift, + scale_tril=tril, + scale_diag=diag, + scale_identity_multiplier=identity_multiplier, + validate_args=validate_args, + assert_positive=False, + shape_hint=shape_hint) + + if perturb_factor is not None: + return linalg.LinearOperatorLowRankUpdate( + scale, + u=perturb_factor, + diag_update=perturb_diag, + is_diag_update_positive=perturb_diag is None, + is_non_singular=True, # Implied by is_positive_definite=True. + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + return scale + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self._is_only_identity_multiplier: + y *= self._scale + if self.shift is not None: + return y + self.shift + return y + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_check_scale() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self._is_only_identity_multiplier: + return x / self._scale + + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. + x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): + if self._is_only_identity_multiplier: + # We don't pad in this case and instead let the fldj be applied + # via broadcast. + event_size = distribution_util.pick_vector( + math_ops.equal(self._shaper.event_ndims, 0), + [1], array_ops.shape(x))[-1] + event_size = math_ops.cast(event_size, dtype=self._scale.dtype) + return math_ops.log(math_ops.abs(self._scale)) * event_size + return self.scale.log_abs_determinant() + + def _maybe_check_scale(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py deleted file mode 100644 index 05bb9c2f9bdf35e222c94db3491157893da64ebd..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py +++ /dev/null @@ -1,403 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Affine bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import linalg -from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Affine", -] - - -def _as_tensor(x, name): - """Convenience to convert to `Tensor` or leave as `None`.""" - return None if x is None else ops.convert_to_tensor(x, name=name) - - -class Affine(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. - - In TF parlance, the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - The `scale` term is applied without necessarily materializing constituent - matrices, i.e., the matmul is [matrix-free]( - https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - - Examples: - - ```python - # Y = X - b = Affine() - - # Y = X + shift - b = Affine(shift=[1., 2, 3]) - - # Y = 2 * I @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_identity_multiplier=2.) - - # Y = tf.diag(d1) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[-1., 2, 1]) # Implicitly 3x3. - - # Y = (I + v * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[1., 3, 3], # Implicitly 3x3. - scale_perturb_diag=[2., 1], # Implicitly 2x2. - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - ``` - - """ - - def __init__(self, - shift=None, - scale_identity_multiplier=None, - scale_diag=None, - scale_tril=None, - scale_perturb_factor=None, - scale_perturb_diag=None, - event_ndims=1, - validate_args=False, - name="affine"): - """Instantiates the `Affine` bijector. - - This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, - giving the forward operation: - - ```none - Y = g(X) = scale @ X + shift - ``` - - where the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are - specified then `scale += IdentityMatrix`. Otherwise specifying a - `scale` argument has the semantics of `scale += Expand(arg)`, i.e., - `scale_diag != None` means `scale += tf.diag(scale_diag)`. - - Args: - shift: Floating-point `Tensor`. If this is set to `None`, no shift is - applied. - scale_identity_multiplier: floating point rank 0 `Tensor` representing a - scaling done to the identity matrix. - When `scale_identity_multiplier = scale_diag = scale_tril = None` then - `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added - to `scale`. - scale_diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - When `None` no diagonal term is added to `scale`. - scale_tril: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k - lower triangular matrix. - When `None` no `scale_tril` term is added to `scale`. - The upper triangular elements above the diagonal are ignored. - scale_perturb_factor: Floating-point `Tensor` representing factor matrix - with last two dimensions of shape `(k, r)`. When `None`, no rank-r - update is added to `scale`. - scale_perturb_diag: Floating-point `Tensor` representing the diagonal - matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which - represents an `r x r` diagonal matrix. When `None` low rank updates will - take the form `scale_perturb_factor * scale_perturb_factor.T`. - event_ndims: Scalar `int` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `perturb_diag` is specified but not `perturb_factor`. - TypeError: if `shift` has different `dtype` from `scale` arguments. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - - # Ambiguous definition of low rank update. - if scale_perturb_diag is not None and scale_perturb_factor is None: - raise ValueError("When scale_perturb_diag is specified, " - "scale_perturb_factor must be specified.") - - # Special case, only handling a scaled identity matrix. We don't know its - # dimensions, so this is special cased. - # We don't check identity_multiplier, since below we set it to 1. if all - # other scale args are None. - self._is_only_identity_multiplier = (scale_tril is None and - scale_diag is None and - scale_perturb_factor is None) - - with self._name_scope("init", values=[ - shift, scale_identity_multiplier, scale_diag, scale_tril, - scale_perturb_diag, scale_perturb_factor]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0, 1): - raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - - if event_ndims_const == 0 and not self._is_only_identity_multiplier: - raise ValueError( - "If event_ndims == 0, the only scale argument you can pass is " - "scale_identity_multiplier. All others operate on vectors.") - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - dtype = shift.dtype.base_dtype - self._shift = shift - - # When no args are specified, pretend the scale matrix is the identity - # matrix. - if (self._is_only_identity_multiplier and - scale_identity_multiplier is None): - scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) - - # self._create_scale_operator returns a LinearOperator in all cases - # except if self._is_only_identity_multiplier; in which case it - # returns a scalar Tensor. - scale = self._create_scale_operator( - identity_multiplier=scale_identity_multiplier, - diag=scale_diag, - tril=scale_tril, - perturb_diag=scale_perturb_diag, - perturb_factor=scale_perturb_factor, - shift=shift, - validate_args=validate_args) - - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - - if scale is not None and not self._is_only_identity_multiplier: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - else: - # We won't need shape inference when scale is None or when scale is a - # scalar. - batch_ndims = 0 - self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(Affine, self).__init__( - event_ndims=event_ndims, - graph_parents=( - [event_ndims] + - [self._scale] if tensor_util.is_tensor(self._scale) - else self._scale.graph_parents + - [self._shift] if self._shift is not None else []), - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - def _create_scale_operator(self, identity_multiplier, diag, tril, - perturb_diag, perturb_factor, shift, - validate_args): - """Construct `scale` from various components. - - Args: - identity_multiplier: floating point rank 0 `Tensor` representing a scaling - done to the identity matrix. - diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - tril: Floating-point `Tensor` representing the diagonal matrix. - `scale_tril` has shape [N1, N2, ... k], which represents a k x k lower - triangular matrix. - perturb_diag: Floating-point `Tensor` representing the diagonal matrix of - the low rank update. - perturb_factor: Floating-point `Tensor` representing factor matrix. - shift: Floating-point `Tensor` representing `shift in `scale @ X + shift`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - - Returns: - scale. In the case of scaling by a constant, scale is a - floating point `Tensor`. Otherwise, scale is a `LinearOperator`. - - Raises: - ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`. - """ - identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") - diag = _as_tensor(diag, "diag") - tril = _as_tensor(tril, "tril") - perturb_diag = _as_tensor(perturb_diag, "perturb_diag") - perturb_factor = _as_tensor(perturb_factor, "perturb_factor") - - # If possible, use the low rank update to infer the shape of - # the identity matrix, when scale represents a scaled identity matrix - # with a low rank update. - shape_hint = None - if perturb_factor is not None: - shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) - - if self._is_only_identity_multiplier: - if validate_args: - return control_flow_ops.with_dependencies( - [check_ops.assert_none_equal( - identity_multiplier, - array_ops.zeros([], identity_multiplier.dtype), - ["identity_multiplier should be non-zero."])], - identity_multiplier) - return identity_multiplier - - scale = distribution_util.make_tril_scale( - loc=shift, - scale_tril=tril, - scale_diag=diag, - scale_identity_multiplier=identity_multiplier, - validate_args=validate_args, - assert_positive=False, - shape_hint=shape_hint) - - if perturb_factor is not None: - return linalg.LinearOperatorLowRankUpdate( - scale, - u=perturb_factor, - diag_update=perturb_diag, - is_diag_update_positive=perturb_diag is None, - is_non_singular=True, # Implied by is_positive_definite=True. - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - return scale - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self._is_only_identity_multiplier: - y *= self._scale - if self.shift is not None: - return y + self.shift - return y - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_check_scale() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self._is_only_identity_multiplier: - return x / self._scale - - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. - x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): - if self._is_only_identity_multiplier: - # We don't pad in this case and instead let the fldj be applied - # via broadcast. - event_size = distribution_util.pick_vector( - math_ops.equal(self._shaper.event_ndims, 0), - [1], array_ops.shape(x))[-1] - event_size = math_ops.cast(event_size, dtype=self._scale.dtype) - return math_ops.log(math_ops.abs(self._scale)) * event_size - return self.scale.log_abs_determinant() - - def _maybe_check_scale(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index aca04a89df7c3ee09d5f7cc10f6779e33fa7aa66..89043b1410370074f11f2cfa59b6b6663fa62521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -18,12 +18,214 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linear_operator -_allowed_symbols = ["AffineLinearOperator"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "AffineLinearOperator", +] + + +class AffineLinearOperator(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. + + `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. + + If `X` is a scalar then the forward transformation is: `scale * X + shift` + where `*` denotes the scalar product. + + Note: we don't always simply transpose `X` (but write it this way for + brevity). Actually the input `X` undergoes the following transformation + before being premultiplied by `scale`: + + 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., + `new_sample_shape = [1]`. Otherwise do nothing. + 2. The sample shape is flattened to have one dimension, i.e., + `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. + 3. The sample dim is cyclically rotated left by 1, i.e., + `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the + event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch + dimensions. + + (For more details see `shape.make_batch_of_event_sample_matrices`.) + + The result of the above transformation is that `X` can be regarded as a batch + of matrices where each column is a draw from the distribution. After + premultiplying by `scale`, we take the inverse of this procedure. The input + `Y` also undergoes the same transformation before/after premultiplying by + `inv(scale)`. + + Example Use: + + ```python + linalg = tf.linalg + + x = [1., 2, 3] + + shift = [-1., 0., 1] + diag = [1., 2, 3] + scale = linalg.LinearOperatorDiag(diag) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # y = scale @ x + shift + y = affine.forward(x) # [0., 4, 10] + + shift = [2., 3, 1] + tril = [[1., 0, 0], + [2, 1, 0], + [3, 2, 1]] + scale = linalg.LinearOperatorLowerTriangular(tril) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift + y = affine.forward(x) # [3., 7, 11] + ``` + + """ + + def __init__(self, + shift=None, + scale=None, + event_ndims=1, + validate_args=False, + name="affine_linear_operator"): + """Instantiates the `AffineLinearOperator` bijector. + + Args: + shift: Floating-point `Tensor`. + scale: Subclass of `LinearOperator`. Represents the (batch) positive + definite matrix `M` in `R^{k x k}`. + event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `event_ndims` is not 0 or 1. + TypeError: if `scale` is not a `LinearOperator`. + TypeError: if `shift.dtype` does not match `scale.dtype`. + ValueError: if not `scale.is_non_singular`. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + graph_parents = [] + with self._name_scope("init", values=[shift]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + if tensor_util.constant_value(event_ndims) is not None: + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims not in (0, 1): + raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + graph_parents += [event_ndims] + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + graph_parents += [shift] + dtype = shift.dtype.base_dtype + self._shift = shift + + if scale is not None: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + if not isinstance(scale, linear_operator.LinearOperator): + raise TypeError("scale is not an instance of tf.LinearOperator") + if validate_args and not scale.is_non_singular: + raise ValueError("Scale matrix must be non-singular.") + graph_parents += scale.graph_parents + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + graph_parents += [batch_ndims] + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + else: + batch_ndims = 0 # We won't need shape inference when scale is None. + self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(AffineLinearOperator, self).__init__( + event_ndims=event_ndims, + graph_parents=graph_parents, + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self.scale is not None: + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self.scale is not None: + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. + x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument + if self.scale is None: + return constant_op.constant(0, dtype=x.dtype.base_dtype) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + return self.scale.log_abs_determinant() + + def _maybe_collect_assertions(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py deleted file mode 100644 index 89043b1410370074f11f2cfa59b6b6663fa62521..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""AffineLinearOperator bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.linalg import linear_operator - - -__all__ = [ - "AffineLinearOperator", -] - - -class AffineLinearOperator(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. - - If `X` is a scalar then the forward transformation is: `scale * X + shift` - where `*` denotes the scalar product. - - Note: we don't always simply transpose `X` (but write it this way for - brevity). Actually the input `X` undergoes the following transformation - before being premultiplied by `scale`: - - 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., - `new_sample_shape = [1]`. Otherwise do nothing. - 2. The sample shape is flattened to have one dimension, i.e., - `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. - 3. The sample dim is cyclically rotated left by 1, i.e., - `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the - event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch - dimensions. - - (For more details see `shape.make_batch_of_event_sample_matrices`.) - - The result of the above transformation is that `X` can be regarded as a batch - of matrices where each column is a draw from the distribution. After - premultiplying by `scale`, we take the inverse of this procedure. The input - `Y` also undergoes the same transformation before/after premultiplying by - `inv(scale)`. - - Example Use: - - ```python - linalg = tf.linalg - - x = [1., 2, 3] - - shift = [-1., 0., 1] - diag = [1., 2, 3] - scale = linalg.LinearOperatorDiag(diag) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # y = scale @ x + shift - y = affine.forward(x) # [0., 4, 10] - - shift = [2., 3, 1] - tril = [[1., 0, 0], - [2, 1, 0], - [3, 2, 1]] - scale = linalg.LinearOperatorLowerTriangular(tril) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift - y = affine.forward(x) # [3., 7, 11] - ``` - - """ - - def __init__(self, - shift=None, - scale=None, - event_ndims=1, - validate_args=False, - name="affine_linear_operator"): - """Instantiates the `AffineLinearOperator` bijector. - - Args: - shift: Floating-point `Tensor`. - scale: Subclass of `LinearOperator`. Represents the (batch) positive - definite matrix `M` in `R^{k x k}`. - event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `event_ndims` is not 0 or 1. - TypeError: if `scale` is not a `LinearOperator`. - TypeError: if `shift.dtype` does not match `scale.dtype`. - ValueError: if not `scale.is_non_singular`. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - graph_parents = [] - with self._name_scope("init", values=[shift]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - if tensor_util.constant_value(event_ndims) is not None: - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims not in (0, 1): - raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - graph_parents += [event_ndims] - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - graph_parents += [shift] - dtype = shift.dtype.base_dtype - self._shift = shift - - if scale is not None: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - if not isinstance(scale, linear_operator.LinearOperator): - raise TypeError("scale is not an instance of tf.LinearOperator") - if validate_args and not scale.is_non_singular: - raise ValueError("Scale matrix must be non-singular.") - graph_parents += scale.graph_parents - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - graph_parents += [batch_ndims] - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - else: - batch_ndims = 0 # We won't need shape inference when scale is None. - self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(AffineLinearOperator, self).__init__( - event_ndims=event_ndims, - graph_parents=graph_parents, - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self.scale is not None: - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self.scale is not None: - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. - x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument - if self.scale is None: - return constant_op.constant(0, dtype=x.dtype.base_dtype) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - return self.scale.log_abs_determinant() - - def _maybe_collect_assertions(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 0db10fb75c8483a8209f39370362b05a03d047ca..3ce7c26213034c7345a20faa803c94a1bfa8d579 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -18,12 +18,151 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.chain_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import itertools -_allowed_symbols = ["Chain"] +from tensorflow.python.framework import constant_op +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Chain", +] + + +class Chain(bijector.Bijector): + """Bijector which applies a sequence of bijectors. + + Example Use: + + ```python + chain = Chain([Exp(), Softplus()], name="one_plus_exp") + ``` + + Results in: + + * Forward: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).forward(x) + = exp.forward(softplus.forward(x)) + = tf.exp(tf.log(1. + tf.exp(x))) + = 1. + tf.exp(x) + ``` + + * Inverse: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).inverse(y) + = softplus.inverse(exp.inverse(y)) + = tf.log(tf.exp(tf.log(y)) - 1.) + = tf.log(y - 1.) + ``` + + """ + + def __init__(self, bijectors=None, validate_args=False, name=None): + """Instantiates `Chain` bijector. + + Args: + bijectors: Python `list` of bijector instances. An empty list makes this + bijector equivalent to the `Identity` bijector. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. Default: + E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. + + Raises: + ValueError: if bijectors have different dtypes. + """ + if bijectors is None: + bijectors = () + self._bijectors = bijectors + + for a_bijector in bijectors: + if not a_bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijector ({})".format( + a_bijector.name)) + + dtype = list(set([b.dtype for b in bijectors])) + if len(dtype) > 2: + raise ValueError("incompatible dtypes: %s" % dtype) + elif len(dtype) == 2: + dtype = dtype[1] if dtype[0] is None else dtype[0] + event_ndims = bijectors[0].event_ndims + elif len(dtype) == 1: + dtype = dtype[0] + event_ndims = bijectors[0].event_ndims + else: + dtype = None + event_ndims = None + + super(Chain, self).__init__( + graph_parents=list(itertools.chain.from_iterable( + b.graph_parents for b in bijectors)), + is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), + validate_args=validate_args, + dtype=dtype, + event_ndims=event_ndims, + name=name or ("identity" if not bijectors else + "_of_".join(["chain"] + [b.name for b in bijectors]))) + + @property + def bijectors(self): + return self._bijectors + + def _shape_helper(self, func_name, input_shape, reverse): + new_shape = input_shape + for b in reversed(self.bijectors) if reverse else self.bijectors: + func = getattr(b, func_name, None) + if func is None: + raise ValueError("unable to call %s on bijector %s (%s)" % + (func_name, b.name, func)) + new_shape = func(new_shape) + return new_shape + + def _forward_event_shape(self, input_shape): + return self._shape_helper("forward_event_shape", input_shape, + reverse=True) + + def _forward_event_shape_tensor(self, input_shape): + return self._shape_helper( + "forward_event_shape_tensor", input_shape, reverse=True) + + def _inverse_event_shape(self, output_shape): + return self._shape_helper("inverse_event_shape", output_shape, + reverse=False) + + def _inverse_event_shape_tensor(self, output_shape): + return self._shape_helper("inverse_event_shape_tensor", output_shape, + reverse=False) + + def _inverse(self, y, **kwargs): + for b in self.bijectors: + y = b.inverse(y, **kwargs.get(b.name, {})) + return y + + def _inverse_log_det_jacobian(self, y, **kwargs): + ildj = constant_op.constant(0., dtype=y.dtype, + name="inverse_log_det_jacobian") + for b in self.bijectors: + ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) + y = b.inverse(y, **kwargs.get(b.name, {})) + return ildj + + def _forward(self, x, **kwargs): + for b in reversed(self.bijectors): + x = b.forward(x, **kwargs.get(b.name, {})) + return x + + def _forward_log_det_jacobian(self, x, **kwargs): + fldj = constant_op.constant(0., dtype=x.dtype, + name="forward_log_det_jacobian") + for b in reversed(self.bijectors): + fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) + x = b.forward(x, **kwargs.get(b.name, {})) + return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py deleted file mode 100644 index 3ce7c26213034c7345a20faa803c94a1bfa8d579..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Chain bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools - -from tensorflow.python.framework import constant_op -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Chain", -] - - -class Chain(bijector.Bijector): - """Bijector which applies a sequence of bijectors. - - Example Use: - - ```python - chain = Chain([Exp(), Softplus()], name="one_plus_exp") - ``` - - Results in: - - * Forward: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).forward(x) - = exp.forward(softplus.forward(x)) - = tf.exp(tf.log(1. + tf.exp(x))) - = 1. + tf.exp(x) - ``` - - * Inverse: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).inverse(y) - = softplus.inverse(exp.inverse(y)) - = tf.log(tf.exp(tf.log(y)) - 1.) - = tf.log(y - 1.) - ``` - - """ - - def __init__(self, bijectors=None, validate_args=False, name=None): - """Instantiates `Chain` bijector. - - Args: - bijectors: Python `list` of bijector instances. An empty list makes this - bijector equivalent to the `Identity` bijector. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. Default: - E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. - - Raises: - ValueError: if bijectors have different dtypes. - """ - if bijectors is None: - bijectors = () - self._bijectors = bijectors - - for a_bijector in bijectors: - if not a_bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijector ({})".format( - a_bijector.name)) - - dtype = list(set([b.dtype for b in bijectors])) - if len(dtype) > 2: - raise ValueError("incompatible dtypes: %s" % dtype) - elif len(dtype) == 2: - dtype = dtype[1] if dtype[0] is None else dtype[0] - event_ndims = bijectors[0].event_ndims - elif len(dtype) == 1: - dtype = dtype[0] - event_ndims = bijectors[0].event_ndims - else: - dtype = None - event_ndims = None - - super(Chain, self).__init__( - graph_parents=list(itertools.chain.from_iterable( - b.graph_parents for b in bijectors)), - is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), - validate_args=validate_args, - dtype=dtype, - event_ndims=event_ndims, - name=name or ("identity" if not bijectors else - "_of_".join(["chain"] + [b.name for b in bijectors]))) - - @property - def bijectors(self): - return self._bijectors - - def _shape_helper(self, func_name, input_shape, reverse): - new_shape = input_shape - for b in reversed(self.bijectors) if reverse else self.bijectors: - func = getattr(b, func_name, None) - if func is None: - raise ValueError("unable to call %s on bijector %s (%s)" % - (func_name, b.name, func)) - new_shape = func(new_shape) - return new_shape - - def _forward_event_shape(self, input_shape): - return self._shape_helper("forward_event_shape", input_shape, - reverse=True) - - def _forward_event_shape_tensor(self, input_shape): - return self._shape_helper( - "forward_event_shape_tensor", input_shape, reverse=True) - - def _inverse_event_shape(self, output_shape): - return self._shape_helper("inverse_event_shape", output_shape, - reverse=False) - - def _inverse_event_shape_tensor(self, output_shape): - return self._shape_helper("inverse_event_shape_tensor", output_shape, - reverse=False) - - def _inverse(self, y, **kwargs): - for b in self.bijectors: - y = b.inverse(y, **kwargs.get(b.name, {})) - return y - - def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant(0., dtype=y.dtype, - name="inverse_log_det_jacobian") - for b in self.bijectors: - ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) - y = b.inverse(y, **kwargs.get(b.name, {})) - return ildj - - def _forward(self, x, **kwargs): - for b in reversed(self.bijectors): - x = b.forward(x, **kwargs.get(b.name, {})) - return x - - def _forward_log_det_jacobian(self, x, **kwargs): - fldj = constant_op.constant(0., dtype=x.dtype, - name="forward_log_det_jacobian") - for b in reversed(self.bijectors): - fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) - x = b.forward(x, **kwargs.get(b.name, {})) - return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 4686af8bc42a3232cb3a34f2cfcce8323c5896dd..cbd60f92a60612c6cf791b2c7708a3310c6e2b6b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -18,12 +18,219 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["CholeskyOuterProduct"] +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "CholeskyOuterProduct", +] + + +class CholeskyOuterProduct(bijector.Bijector): + """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. + + `event_ndims` must be 0 or 2, i.e., scalar or matrix. + + Note: the upper-triangular part of X is ignored (whether or not its zero). + + The surjectivity of g as a map from the set of n x n positive-diagonal + lower-triangular matrices to the set of SPD matrices follows immediately from + executing the Cholesky factorization algorithm on an SPD matrix A to produce a + positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. + + To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular + with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then + `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. + Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal + lower-triangular matrix follows from `inv(L_1)` being positive-diagonal + lower-triangular (which follows from the diagonal of a triangular matrix being + its spectrum), and that the product of two positive-diagonal lower-triangular + matrices is another positive-diagonal lower-triangular matrix. + + A simple inductive argument (proceding one column of L_3 at a time) shows + that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- + diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. + + Examples: + + ```python + bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 2], [2, 5]], i.e., x @ x.T + + bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) + # Result: [[1., 0], [2, 1]], i.e., cholesky(y). + ``` + + """ + + def __init__(self, event_ndims=2, validate_args=False, + name="cholesky_outer_product"): + """Instantiates the `CholeskyOuterProduct` bijector. + + Args: + event_ndims: `constant` `int32` scalar `Tensor` indicating the number of + dimensions associated with a particular draw from the distribution. Must + be 0 or 2. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if event_ndims is neither 0 or 2. + """ + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 2]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") + self._static_event_ndims = event_ndims + super(CholeskyOuterProduct, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self._static_event_ndims == 0: + return math_ops.square(x) + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least(x, 2) + shape = array_ops.shape(x) + is_square = check_ops.assert_equal(shape[-2], shape[-1]) + x = control_flow_ops.with_dependencies([is_matrix, is_square], x) + # For safety, explicitly zero-out the upper triangular part. + x = array_ops.matrix_band_part(x, -1, 0) + return math_ops.matmul(x, x, adjoint_b=True) + + def _inverse(self, y): + return (math_ops.sqrt(y) if self._static_event_ndims == 0 + else linalg_ops.cholesky(y)) + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(x=self._inverse(y)) + + def _forward_log_det_jacobian(self, x): + # Let Y be a symmetric, positive definite matrix and write: + # Y = X X.T + # where X is lower-triangular. + # + # Observe that, + # dY[i,j]/dX[a,b] + # = d/dX[a,b] { X[i,:] X[j,:] } + # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } + # + # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is + # symmetric and X is lower-triangular, we need vectors of dimension: + # d = p (p + 1) / 2 + # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., + # k = { i (i + 1) / 2 + j i>=j + # { undef ij thus i,j!=a. + # + # Since the Jacobian is lower-triangular, we need only compute the product + # of diagonal elements: + # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] + # = X[j,j] + I[i=j] X[i,j] + # = 2 X[j,j]. + # Since there is a 2 X[j,j] term for every lower-triangular element of X we + # conclude: + # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. + if self._static_event_ndims == 0: + if self.validate_args: + is_positive = check_ops.assert_positive( + x, message="All elements must be positive.") + x = control_flow_ops.with_dependencies([is_positive], x) + return np.log(2.) + math_ops.log(x) + + diag = array_ops.matrix_diag_part(x) + + # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output + # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the + # output is unchanged. + diag = self._make_columnar(diag) + + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must be a (batch of) matrix.") + shape = array_ops.shape(x) + is_square = check_ops.assert_equal( + shape[-2], shape[-1], + message="Input must be a (batch of) square matrix.") + # Assuming lower-triangular means we only need check diag>0. + is_positive_definite = check_ops.assert_positive( + diag, message="Input must be positive definite.") + x = control_flow_ops.with_dependencies( + [is_matrix, is_square, is_positive_definite], x) + + # Create a vector equal to: [p, p-1, ..., 2, 1]. + if x.get_shape().ndims is None or x.get_shape()[-1].value is None: + p_int = array_ops.shape(x)[-1] + p_float = math_ops.cast(p_int, dtype=x.dtype) + else: + p_int = x.get_shape()[-1].value + p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) + exponents = math_ops.linspace(p_float, 1., p_int) + + sum_weighted_log_diag = array_ops.squeeze( + math_ops.matmul(math_ops.log(diag), + exponents[..., array_ops.newaxis]), + squeeze_dims=-1) + fldj = p_float * np.log(2.) + sum_weighted_log_diag + + return fldj + + def _make_columnar(self, x): + """Ensures non-scalar input has at least one column. + + Example: + If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. + + If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. + + If `x = 1` then the output is unchanged. + + Args: + x: `Tensor`. + + Returns: + columnar_x: `Tensor` with at least two dimensions. + """ + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 1: + x = x[array_ops.newaxis, :] + return x + shape = array_ops.shape(x) + maybe_expanded_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 1), + [1], np.array([], dtype=np.int32)), + shape[-1:], + ], 0) + return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py deleted file mode 100644 index cbd60f92a60612c6cf791b2c7708a3310c6e2b6b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""CholeskyOuterProduct bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "CholeskyOuterProduct", -] - - -class CholeskyOuterProduct(bijector.Bijector): - """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. - - `event_ndims` must be 0 or 2, i.e., scalar or matrix. - - Note: the upper-triangular part of X is ignored (whether or not its zero). - - The surjectivity of g as a map from the set of n x n positive-diagonal - lower-triangular matrices to the set of SPD matrices follows immediately from - executing the Cholesky factorization algorithm on an SPD matrix A to produce a - positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. - - To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular - with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then - `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. - Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal - lower-triangular matrix follows from `inv(L_1)` being positive-diagonal - lower-triangular (which follows from the diagonal of a triangular matrix being - its spectrum), and that the product of two positive-diagonal lower-triangular - matrices is another positive-diagonal lower-triangular matrix. - - A simple inductive argument (proceding one column of L_3 at a time) shows - that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- - diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - - Examples: - - ```python - bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) - # Result: [[1., 2], [2, 5]], i.e., x @ x.T - - bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) - # Result: [[1., 0], [2, 1]], i.e., cholesky(y). - ``` - - """ - - def __init__(self, event_ndims=2, validate_args=False, - name="cholesky_outer_product"): - """Instantiates the `CholeskyOuterProduct` bijector. - - Args: - event_ndims: `constant` `int32` scalar `Tensor` indicating the number of - dimensions associated with a particular draw from the distribution. Must - be 0 or 2. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if event_ndims is neither 0 or 2. - """ - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 2]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") - self._static_event_ndims = event_ndims - super(CholeskyOuterProduct, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self._static_event_ndims == 0: - return math_ops.square(x) - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least(x, 2) - shape = array_ops.shape(x) - is_square = check_ops.assert_equal(shape[-2], shape[-1]) - x = control_flow_ops.with_dependencies([is_matrix, is_square], x) - # For safety, explicitly zero-out the upper triangular part. - x = array_ops.matrix_band_part(x, -1, 0) - return math_ops.matmul(x, x, adjoint_b=True) - - def _inverse(self, y): - return (math_ops.sqrt(y) if self._static_event_ndims == 0 - else linalg_ops.cholesky(y)) - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(x=self._inverse(y)) - - def _forward_log_det_jacobian(self, x): - # Let Y be a symmetric, positive definite matrix and write: - # Y = X X.T - # where X is lower-triangular. - # - # Observe that, - # dY[i,j]/dX[a,b] - # = d/dX[a,b] { X[i,:] X[j,:] } - # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } - # - # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is - # symmetric and X is lower-triangular, we need vectors of dimension: - # d = p (p + 1) / 2 - # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., - # k = { i (i + 1) / 2 + j i>=j - # { undef ij thus i,j!=a. - # - # Since the Jacobian is lower-triangular, we need only compute the product - # of diagonal elements: - # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] - # = X[j,j] + I[i=j] X[i,j] - # = 2 X[j,j]. - # Since there is a 2 X[j,j] term for every lower-triangular element of X we - # conclude: - # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. - if self._static_event_ndims == 0: - if self.validate_args: - is_positive = check_ops.assert_positive( - x, message="All elements must be positive.") - x = control_flow_ops.with_dependencies([is_positive], x) - return np.log(2.) + math_ops.log(x) - - diag = array_ops.matrix_diag_part(x) - - # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output - # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the - # output is unchanged. - diag = self._make_columnar(diag) - - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least( - x, 2, message="Input must be a (batch of) matrix.") - shape = array_ops.shape(x) - is_square = check_ops.assert_equal( - shape[-2], shape[-1], - message="Input must be a (batch of) square matrix.") - # Assuming lower-triangular means we only need check diag>0. - is_positive_definite = check_ops.assert_positive( - diag, message="Input must be positive definite.") - x = control_flow_ops.with_dependencies( - [is_matrix, is_square, is_positive_definite], x) - - # Create a vector equal to: [p, p-1, ..., 2, 1]. - if x.get_shape().ndims is None or x.get_shape()[-1].value is None: - p_int = array_ops.shape(x)[-1] - p_float = math_ops.cast(p_int, dtype=x.dtype) - else: - p_int = x.get_shape()[-1].value - p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) - exponents = math_ops.linspace(p_float, 1., p_int) - - sum_weighted_log_diag = array_ops.squeeze( - math_ops.matmul(math_ops.log(diag), - exponents[..., array_ops.newaxis]), - squeeze_dims=-1) - fldj = p_float * np.log(2.) + sum_weighted_log_diag - - return fldj - - def _make_columnar(self, x): - """Ensures non-scalar input has at least one column. - - Example: - If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. - - If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. - - If `x = 1` then the output is unchanged. - - Args: - x: `Tensor`. - - Returns: - columnar_x: `Tensor` with at least two dimensions. - """ - if x.get_shape().ndims is not None: - if x.get_shape().ndims == 1: - x = x[array_ops.newaxis, :] - return x - shape = array_ops.shape(x) - maybe_expanded_shape = array_ops.concat([ - shape[:-1], - distribution_util.pick_vector( - math_ops.equal(array_ops.rank(x), 1), - [1], np.array([], dtype=np.int32)), - shape[-1:], - ], 0) - return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py index d254b635d28099a09a2054536f04ffee3a355b2f..ccb1f029277bc07011df7be047a075274f2b3a27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py @@ -18,12 +18,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["ConditionalBijector"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = ["ConditionalBijector"] + + +class ConditionalBijector(bijector.Bijector): + """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward(self, x, name="forward", **condition_kwargs): + return self._call_forward(x, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse(self, y, name="inverse", **condition_kwargs): + return self._call_inverse(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse_log_det_jacobian( + self, y, name="inverse_log_det_jacobian", **condition_kwargs): + return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward_log_det_jacobian( + self, x, name="forward_log_det_jacobian", **condition_kwargs): + return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py deleted file mode 100644 index ccb1f029277bc07011df7be047a075274f2b3a27..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ConditionalBijector base.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = ["ConditionalBijector"] - - -class ConditionalBijector(bijector.Bijector): - """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward(self, x, name="forward", **condition_kwargs): - return self._call_forward(x, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse(self, y, name="inverse", **condition_kwargs): - return self._call_inverse(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse_log_det_jacobian( - self, y, name="inverse_log_det_jacobian", **condition_kwargs): - return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward_log_det_jacobian( - self, x, name="forward_log_det_jacobian", **condition_kwargs): - return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 399d713098eb7223601beb9518dc51dd6160ad64..b1ff840d62a73c941a4d67dec73b5c9f4d5353f9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -18,12 +18,49 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.exp_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import power_transform -_allowed_symbols = ["Exp"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Exp", +] + + +class Exp(power_transform.PowerTransform): + """Compute `Y = g(X) = exp(X)`. + + Example Use: + + ```python + # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + exp = Exp(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + exp(x) == exp.forward(x) + log(x) == exp.inverse(x) + ``` + + Note: the exp(.) is applied element-wise but the Jacobian is a reduction + over the event space. + """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="exp"): + """Instantiates the `Exp` bijector. + + Args: + event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + super(Exp, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py deleted file mode 100644 index b1ff840d62a73c941a4d67dec73b5c9f4d5353f9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Exp bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import power_transform - - -__all__ = [ - "Exp", -] - - -class Exp(power_transform.PowerTransform): - """Compute `Y = g(X) = exp(X)`. - - Example Use: - - ```python - # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - exp = Exp(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - exp(x) == exp.forward(x) - log(x) == exp.inverse(x) - ``` - - Note: the exp(.) is applied element-wise but the Jacobian is a reduction - over the event space. - """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="exp"): - """Instantiates the `Exp` bijector. - - Args: - event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - super(Exp, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index cf37aa51115ed98ab263bc03bcb297a03432a7ae..67f39785563255be0fe154aca3cbcf01c6a01e73 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -18,12 +18,107 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Gumbel"] +__all__ = [ + "Gumbel", +] -remove_undocumented(__name__, _allowed_symbols) + +class Gumbel(bijector.Bijector): + """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + + This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): + + ```none + Y ~ Gumbel(loc, scale) + pdf(y; loc, scale) = exp( + -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale + ``` + """ + + def __init__(self, + loc=0., + scale=1., + event_ndims=0, + validate_args=False, + name="gumbel"): + """Instantiates the `Gumbel` bijector. + + Args: + loc: Float-like `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + scale: Positive Float-like `Tensor` that is the same dtype and is + broadcastable with `loc`. + This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[loc, scale]): + self._loc = ops.convert_to_tensor(loc, name="loc") + self._scale = ops.convert_to_tensor(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, message="Argument scale was not positive") + ], self._scale) + + super(Gumbel, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def loc(self): + """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._loc + + @property + def scale(self): + """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._scale + + def _forward(self, x): + z = (x - self.loc) / self.scale + return math_ops.exp(-math_ops.exp(-z)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.loc - self.scale * math_ops.log(-math_ops.log(y)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + event_dims = self._event_dims_tensor(x) + z = (x - self.loc) / self.scale + return math_ops.reduce_sum( + -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, + constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py deleted file mode 100644 index 67f39785563255be0fe154aca3cbcf01c6a01e73..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Gumbel bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "Gumbel", -] - - -class Gumbel(bijector.Bijector): - """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - - This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): - - ```none - Y ~ Gumbel(loc, scale) - pdf(y; loc, scale) = exp( - -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale - ``` - """ - - def __init__(self, - loc=0., - scale=1., - event_ndims=0, - validate_args=False, - name="gumbel"): - """Instantiates the `Gumbel` bijector. - - Args: - loc: Float-like `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - scale: Positive Float-like `Tensor` that is the same dtype and is - broadcastable with `loc`. - This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[loc, scale]): - self._loc = ops.convert_to_tensor(loc, name="loc") - self._scale = ops.convert_to_tensor(scale, name="scale") - check_ops.assert_same_float_dtype([self._loc, self._scale]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, message="Argument scale was not positive") - ], self._scale) - - super(Gumbel, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def loc(self): - """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._loc - - @property - def scale(self): - """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._scale - - def _forward(self, x): - z = (x - self.loc) / self.scale - return math_ops.exp(-math_ops.exp(-z)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.loc - self.scale * math_ops.log(-math_ops.log(y)) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - event_dims = self._event_dims_tensor(x) - z = (x - self.loc) / self.scale - return math_ops.reduce_sum( - -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, - constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index db10c3fc3a9135b4c408ada74622ba9b360f9ec1..fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -18,12 +18,124 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.inline_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Inline"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Inline", +] + + +class Inline(bijector.Bijector): + """Bijector constructed from custom callables. + + Example Use: + + ```python + exp = Inline( + forward_fn=tf.exp, + inverse_fn=tf.log, + inverse_log_det_jacobian_fn=( + lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), + name="exp") + ``` + + The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. + """ + + def __init__(self, + forward_fn=None, + inverse_fn=None, + inverse_log_det_jacobian_fn=None, + forward_log_det_jacobian_fn=None, + forward_event_shape_fn=None, + forward_event_shape_tensor_fn=None, + inverse_event_shape_fn=None, + inverse_event_shape_tensor_fn=None, + is_constant_jacobian=False, + validate_args=False, + name="inline"): + """Creates a `Bijector` from callables. + + Args: + forward_fn: Python callable implementing the forward transformation. + inverse_fn: Python callable implementing the inverse transformation. + inverse_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the inverse transformation. + forward_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the forward transformation. + forward_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + forward_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + is_constant_jacobian: Python `bool` indicating that the Jacobian is + constant for all input arguments. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + super(Inline, self).__init__( + event_ndims=0, + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + self._forward_fn = forward_fn + self._inverse_fn = inverse_fn + self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn + self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn + self._forward_event_shape_fn = forward_event_shape_fn + self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn + self._inverse_event_shape_fn = inverse_event_shape_fn + self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn + + def _forward_event_shape(self, input_shape): + if self._forward_event_shape_fn is None: + # By default assume shape doesn't change. + return input_shape + return self._forward_event_shape_fn(input_shape) + + def _forward_event_shape_tensor(self, input_shape): + if self._forward_event_shape_tensor_fn is None: + # By default assume shape doesn't change. + return input_shape + return self._forward_event_shape_tensor_fn(input_shape) + + def _inverse_event_shape(self, output_shape): + if self._inverse_event_shape_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_fn(output_shape) + + def _inverse_event_shape_tensor(self, output_shape): + if self._inverse_event_shape_tensor_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_tensor_fn(output_shape) + + def _forward(self, x, **kwargs): + if not callable(self._forward_fn): + raise NotImplementedError( + "forward_fn is not a callable function.") + return self._forward_fn(x, **kwargs) + + def _inverse(self, y, **kwargs): + if not callable(self._inverse_fn): + raise NotImplementedError( + "inverse_fn is not a callable function.") + return self._inverse_fn(y, **kwargs) + + def _inverse_log_det_jacobian(self, y, **kwargs): + if not callable(self._inverse_log_det_jacobian_fn): + raise NotImplementedError( + "inverse_log_det_jacobian_fn is not a callable function.") + return self._inverse_log_det_jacobian_fn(y, **kwargs) + + def _forward_log_det_jacobian(self, y, **kwargs): + if not callable(self._forward_log_det_jacobian_fn): + raise NotImplementedError( + "forward_log_det_jacobian_fn is not a callable function.") + return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py deleted file mode 100644 index fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Inline bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Inline", -] - - -class Inline(bijector.Bijector): - """Bijector constructed from custom callables. - - Example Use: - - ```python - exp = Inline( - forward_fn=tf.exp, - inverse_fn=tf.log, - inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), - name="exp") - ``` - - The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. - """ - - def __init__(self, - forward_fn=None, - inverse_fn=None, - inverse_log_det_jacobian_fn=None, - forward_log_det_jacobian_fn=None, - forward_event_shape_fn=None, - forward_event_shape_tensor_fn=None, - inverse_event_shape_fn=None, - inverse_event_shape_tensor_fn=None, - is_constant_jacobian=False, - validate_args=False, - name="inline"): - """Creates a `Bijector` from callables. - - Args: - forward_fn: Python callable implementing the forward transformation. - inverse_fn: Python callable implementing the inverse transformation. - inverse_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the inverse transformation. - forward_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the forward transformation. - forward_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. - forward_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - is_constant_jacobian: Python `bool` indicating that the Jacobian is - constant for all input arguments. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - super(Inline, self).__init__( - event_ndims=0, - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - self._forward_fn = forward_fn - self._inverse_fn = inverse_fn - self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn - self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn - self._forward_event_shape_fn = forward_event_shape_fn - self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn - self._inverse_event_shape_fn = inverse_event_shape_fn - self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn - - def _forward_event_shape(self, input_shape): - if self._forward_event_shape_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_fn(input_shape) - - def _forward_event_shape_tensor(self, input_shape): - if self._forward_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_tensor_fn(input_shape) - - def _inverse_event_shape(self, output_shape): - if self._inverse_event_shape_fn is None: - # By default assume shape doesn't change. - return output_shape - return self._inverse_event_shape_fn(output_shape) - - def _inverse_event_shape_tensor(self, output_shape): - if self._inverse_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return output_shape - return self._inverse_event_shape_tensor_fn(output_shape) - - def _forward(self, x, **kwargs): - if not callable(self._forward_fn): - raise NotImplementedError( - "forward_fn is not a callable function.") - return self._forward_fn(x, **kwargs) - - def _inverse(self, y, **kwargs): - if not callable(self._inverse_fn): - raise NotImplementedError( - "inverse_fn is not a callable function.") - return self._inverse_fn(y, **kwargs) - - def _inverse_log_det_jacobian(self, y, **kwargs): - if not callable(self._inverse_log_det_jacobian_fn): - raise NotImplementedError( - "inverse_log_det_jacobian_fn is not a callable function.") - return self._inverse_log_det_jacobian_fn(y, **kwargs) - - def _forward_log_det_jacobian(self, y, **kwargs): - if not callable(self._forward_log_det_jacobian_fn): - raise NotImplementedError( - "forward_log_det_jacobian_fn is not a callable function.") - return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index c134e10109ce5065eb58de1d847e3c487258954c..2c603fe61f36dd27f4984fe6c13c11f2fb534321 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -18,12 +18,85 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.invert_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector as bijector_lib -_allowed_symbols = ["Invert"] +__all__ = [ + "Invert", +] -remove_undocumented(__name__, _allowed_symbols) + +class Invert(bijector_lib.Bijector): + """Bijector which inverts another Bijector. + + Example Use: [ExpGammaDistribution (see Background & Context)]( + https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) + models `Y=log(X)` where `X ~ Gamma`. + + ```python + exp_gamma_distribution = TransformedDistribution( + distribution=Gamma(concentration=1., rate=2.), + bijector=bijector.Invert(bijector.Exp()) + ``` + + """ + + def __init__(self, bijector, validate_args=False, name=None): + """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. + + Note: An inverted bijector's `inverse_log_det_jacobian` is often more + efficient if the base bijector implements `_forward_log_det_jacobian`. If + `_forward_log_det_jacobian` is not implemented then the following code is + used: + + ```python + y = self.inverse(x, **kwargs) + return -self.inverse_log_det_jacobian(y, **kwargs) + ``` + + Args: + bijector: Bijector instance. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + + if not bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijectors.") + + self._bijector = bijector + super(Invert, self).__init__( + event_ndims=bijector.event_ndims, + graph_parents=bijector.graph_parents, + is_constant_jacobian=bijector.is_constant_jacobian, + validate_args=validate_args, + dtype=bijector.dtype, + name=name or "_".join(["invert", bijector.name])) + + def _forward_event_shape(self, input_shape): + return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access + + def _forward_event_shape_tensor(self, input_shape): + return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access + + def _inverse_event_shape(self, output_shape): + return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access + + def _inverse_event_shape_tensor(self, output_shape): + return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access + + @property + def bijector(self): + return self._bijector + + def _forward(self, x, **kwargs): + return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access + + def _inverse(self, y, **kwargs): + return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access + + def _inverse_log_det_jacobian(self, y, **kwargs): + return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access + + def _forward_log_det_jacobian(self, x, **kwargs): + return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py deleted file mode 100644 index 2c603fe61f36dd27f4984fe6c13c11f2fb534321..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Invert bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector as bijector_lib - -__all__ = [ - "Invert", -] - - -class Invert(bijector_lib.Bijector): - """Bijector which inverts another Bijector. - - Example Use: [ExpGammaDistribution (see Background & Context)]( - https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) - models `Y=log(X)` where `X ~ Gamma`. - - ```python - exp_gamma_distribution = TransformedDistribution( - distribution=Gamma(concentration=1., rate=2.), - bijector=bijector.Invert(bijector.Exp()) - ``` - - """ - - def __init__(self, bijector, validate_args=False, name=None): - """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. - - Note: An inverted bijector's `inverse_log_det_jacobian` is often more - efficient if the base bijector implements `_forward_log_det_jacobian`. If - `_forward_log_det_jacobian` is not implemented then the following code is - used: - - ```python - y = self.inverse(x, **kwargs) - return -self.inverse_log_det_jacobian(y, **kwargs) - ``` - - Args: - bijector: Bijector instance. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - - if not bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijectors.") - - self._bijector = bijector - super(Invert, self).__init__( - event_ndims=bijector.event_ndims, - graph_parents=bijector.graph_parents, - is_constant_jacobian=bijector.is_constant_jacobian, - validate_args=validate_args, - dtype=bijector.dtype, - name=name or "_".join(["invert", bijector.name])) - - def _forward_event_shape(self, input_shape): - return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access - - def _forward_event_shape_tensor(self, input_shape): - return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access - - def _inverse_event_shape(self, output_shape): - return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access - - def _inverse_event_shape_tensor(self, output_shape): - return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access - - @property - def bijector(self): - return self._bijector - - def _forward(self, x, **kwargs): - return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access - - def _inverse(self, y, **kwargs): - return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access - - def _inverse_log_det_jacobian(self, y, **kwargs): - return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access - - def _forward_log_det_jacobian(self, x, **kwargs): - return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 132dc570f94719b6c71fb269866c943774481b7e..06c7c61ec3dc3980e0d12a984739dca5a925ac9f 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -18,16 +18,459 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = [ +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops import variable_scope as variable_scope_lib +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ "MaskedAutoregressiveFlow", - "masked_dense", "masked_autoregressive_default_template", + "masked_dense", ] -remove_undocumented(__name__, _allowed_symbols) + +class MaskedAutoregressiveFlow(bijector_lib.Bijector): + """Affine MaskedAutoregressiveFlow bijector for vector-valued events. + + The affine autoregressive flow [1] provides a relatively simple framework for + user-specified (deep) architectures to learn a distribution over vector-valued + events. Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + In the `tf.distributions` framework, a "normalizing flow" is implemented as a + `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + is implemented using a `tf.while_loop` and a deep neural network (DNN) with + masked weights such that the autoregressive property is automatically met in + the `inverse`. + + A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the + (expensive) forward-mode calculation to draw samples and the (cheap) + reverse-mode calculation to compute log-probabilities. Conversely, a + `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses + the (expensive) forward-mode calculation to compute log-probabilities and the + (cheap) reverse-mode calculation to compute samples. See "Example Use" + [below] for more details. + + Given a `shift_and_log_scale_fn`, the forward and inverse transformations are + (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` + must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka + "alpha" [2]) such that each are broadcastable with the arguments to `forward` + and `inverse`, i.e., such that the calculations in `forward`, `inverse` + [below] are possible. + + For convenience, `masked_autoregressive_default_template` is offered as a + possible `shift_and_log_scale_fn` function. It implements the MADE + architecture [2]. MADE is a feed-forward network that computes a `shift` and + `log(scale)` using `masked_dense` layers in a deep neural network. Weights are + masked to ensure the autoregressive property. It is possible that this + architecture is suboptimal for your task. To build alternative networks, + either change the arguments to `masked_autoregressive_default_template`, use + the `masked_dense` function to roll-out your own, or use some other + architecture, e.g., using `tf.layers`. + + Warning: no attempt is made to validate that the `shift_and_log_scale_fn` + enforces the "autoregressive property". + + Assuming `shift_and_log_scale_fn` has valid shape and autoregressive + semantics, the forward transformation is, + + ```python + def forward(x): + y = zeros_like(x) + event_size = x.shape[-1] + for _ in range(event_size): + shift, log_scale = shift_and_log_scale_fn(y) + y = x * math_ops.exp(log_scale) + shift + return y + ``` + + and the inverse transformation is, + + ```python + def inverse(y): + shift, log_scale = shift_and_log_scale_fn(y) + return (y - shift) / math_ops.exp(log_scale) + ``` + + Notice that the `inverse` does not need a for-loop. This is because in the + forward pass each calculation of `shift` and `log_scale` is based on the `y` + calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is + equivalent to the scaling used in `forward` after `event_size` passes, i.e., + the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this + also proves the transform is bijective.) + + #### Example Use + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + dims = 5 + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.) E.g., + maf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512])), + event_shape=[dims]) + + x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. + maf.log_prob(x) # Almost free; uses Bijector caching. + maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. + + # [1] also describes an "Inverse Autoregressive Flow", e.g., + iaf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512]))), + event_shape=[dims]) + + x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. + iaf.log_prob(x) # Almost free; uses Bijector caching. + iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. + + # In many (if not most) cases the default `shift_and_log_scale_fn` will be a + # poor choice. Here's an example of using a "shift only" version and with a + # different number/depth of hidden layers. + shift_only = True + maf_no_scale_hidden2 = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + tfb.masked_autoregressive_default_template( + hidden_layers=[32], + shift_only=shift_only), + is_constant_jacobian=shift_only), + event_shape=[dims]) + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + """ + + def __init__(self, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + name=None): + """Creates the MaskedAutoregressiveFlow bijector. + + Args: + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). Suggested default + `masked_autoregressive_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`. When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + name = name or "masked_autoregressive_flow" + self._shift_and_log_scale_fn = shift_and_log_scale_fn + super(MaskedAutoregressiveFlow, self).__init__( + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _forward(self, x): + event_size = array_ops.shape(x)[-1] + y0 = array_ops.zeros_like(x, name="y0") + # call the template once to ensure creation + _ = self._shift_and_log_scale_fn(y0) + def _loop_body(index, y0): + """While-loop body for autoregression calculation.""" + # Set caching device to avoid re-getting the tf.Variable for every while + # loop iteration. + with variable_scope_lib.variable_scope( + variable_scope_lib.get_variable_scope()) as vs: + if vs.caching_device is None: + vs.set_caching_device(lambda op: op.device) + shift, log_scale = self._shift_and_log_scale_fn(y0) + y = x + if log_scale is not None: + y *= math_ops.exp(log_scale) + if shift is not None: + y += shift + return index + 1, y + _, y = control_flow_ops.while_loop( + cond=lambda index, _: index < event_size, + body=_loop_body, + loop_vars=[0, y0]) + return y + + def _inverse(self, y): + shift, log_scale = self._shift_and_log_scale_fn(y) + x = y + if shift is not None: + x -= shift + if log_scale is not None: + x *= math_ops.exp(-log_scale) + return x + + def _inverse_log_det_jacobian(self, y): + _, log_scale = self._shift_and_log_scale_fn(y) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + +MASK_INCLUSIVE = "inclusive" +MASK_EXCLUSIVE = "exclusive" + + +def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): + """Generate the slices for building an autoregressive mask.""" + # TODO(b/67594795): Better support of dynamic shape. + slices = [] + col = 0 + d_in = n_in // num_blocks + d_out = n_out // num_blocks + row = d_out if mask_type == MASK_EXCLUSIVE else 0 + for _ in range(num_blocks): + row_slice = slice(row, None) + col_slice = slice(col, col + d_in) + slices.append([row_slice, col_slice]) + col += d_in + row += d_out + return slices + + +def _gen_mask(num_blocks, + n_in, + n_out, + mask_type=MASK_EXCLUSIVE, + dtype=dtypes.float32): + """Generate the mask for building an autoregressive dense layer.""" + # TODO(b/67594795): Better support of dynamic shape. + mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) + slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) + for [row_slice, col_slice] in slices: + mask[row_slice, col_slice] = 1 + return mask + + +def masked_dense(inputs, + units, + num_blocks=None, + exclusive=False, + kernel_initializer=None, + reuse=None, + name=None, + *args, + **kwargs): + """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. + + See [1] for detailed explanation. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + inputs: Tensor input. + units: Python `int` scalar representing the dimensionality of the output + space. + num_blocks: Python `int` scalar representing the number of blocks for the + MADE masks. + exclusive: Python `bool` scalar representing whether to zero the diagonal of + the mask, used for the first layer of a MADE. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the + `tf.glorot_random_initializer`. + reuse: Python `bool` scalar representing whether to reuse the weights of a + previous layer by the same name. + name: Python `str` used to describe ops managed by this function. + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + Output tensor. + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + # TODO(b/67594795): Better support of dynamic shape. + input_depth = inputs.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + + mask = _gen_mask(num_blocks, input_depth, units, + MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T + + if kernel_initializer is None: + kernel_initializer = init_ops.glorot_normal_initializer() + + def masked_initializer(shape, dtype=None, partition_info=None): + return mask * kernel_initializer(shape, dtype, partition_info) + + with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): + layer = layers.Dense( + units, + kernel_initializer=masked_initializer, + kernel_constraint=lambda x: mask * x, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse, + *args, + **kwargs) + return layer.apply(inputs) + + +def masked_autoregressive_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + log_scale_min_clip=-5., + log_scale_max_clip=3., + log_scale_clip_gradient=False, + name=None, + *args, + **kwargs): + """Build the MADE Model [1]. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the input and returns the `loc` ("mu" [1]) and + `log_scale` ("alpha" [1]) from the MADE network. + + Warning: This function uses `masked_dense` to create randomly initialized + `tf.Variables`. It is presumed that these will be fit, just as you would any + other neural architecture which uses `tf.layers.dense`. + + #### About Hidden Layers: + + Each element of `hidden_layers` should be greater than the `input_depth` + (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the + neural network). This is necessary to ensure the autoregressivity property. + + #### About Clipping: + + This function also optionally clips the `log_scale` (but possibly not its + gradient). This is useful because if `log_scale` is too small/large it might + underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` + bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` + `bool` indicates whether the gradient should also be clipped. The default does + not clip the gradient; this is useful because it still provides gradient + information (for fitting) yet solves the numerical stability problem. I.e., + `log_scale_clip_gradient = False` means + `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual + `grad[clip(x)] exp(clip(x))`. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + hidden_layers: Python `list`-like of non-negative integer, scalars + indicating the number of units in each hidden layer. Default: `[512, 512]. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed. Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The minimum value to clip by. Default: -5. + log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The maximum value to clip by. Default: 3. + log_scale_clip_gradient: Python `bool` indicating that the gradient of + `tf.clip_by_value` should be preserved. Default: `False`. + name: A name for ops managed by this function. Default: + "masked_autoregressive_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "masked_autoregressive_default_template", + values=[log_scale_min_clip, log_scale_max_clip]): + def _fn(x): + """MADE parameterized via `masked_autoregressive_default_template`.""" + # TODO(b/67594795): Better support of dynamic shape. + input_depth = x.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() + else array_ops.shape(x)) + for i, units in enumerate(hidden_layers): + x = masked_dense( + inputs=x, + units=units, + num_blocks=input_depth, + exclusive=True if i == 0 else False, + activation=activation, + *args, + **kwargs) + x = masked_dense( + inputs=x, + units=(1 if shift_only else 2) * input_depth, + num_blocks=input_depth, + activation=None, + *args, + **kwargs) + if shift_only: + x = array_ops.reshape(x, shape=input_shape) + return x, None + x = array_ops.reshape( + x, shape=array_ops.concat([input_shape, [2]], axis=0)) + shift, log_scale = array_ops.unstack(x, num=2, axis=-1) + which_clip = (math_ops.clip_by_value if log_scale_clip_gradient + else _clip_by_value_preserve_grad) + log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) + return shift, log_scale + return template_ops.make_template( + "masked_autoregressive_default_template", _fn) + + +def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): + """Clips input while leaving gradient unaltered.""" + with ops.name_scope(name, "clip_by_value_preserve_grad", + [x, clip_value_min, clip_value_max]): + clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) + return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py deleted file mode 100644 index ae142883931274b594dbbafbe86bd71e75c621bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MaskedAutoregressiveFlow bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.layers import core as layers -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import template as template_ops -from tensorflow.python.ops import variable_scope as variable_scope_lib -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "MaskedAutoregressiveFlow", - "masked_autoregressive_default_template", - "masked_dense", -] - - -class MaskedAutoregressiveFlow(bijector_lib.Bijector): - """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, - - "Autoregressive models decompose the joint density as a product of - conditionals, and model each conditional in turn. Normalizing flows - transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] - - In other words, the "autoregressive property" is equivalent to the - decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided - `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves - this property by zeroing out weights in its `masked_dense` layers. - - In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" - is implemented using a `tf.while_loop` and a deep neural network (DNN) with - masked weights such that the autoregressive property is automatically met in - the `inverse`. - - A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the - (expensive) forward-mode calculation to draw samples and the (cheap) - reverse-mode calculation to compute log-probabilities. Conversely, a - `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses - the (expensive) forward-mode calculation to compute log-probabilities and the - (cheap) reverse-mode calculation to compute samples. See "Example Use" - [below] for more details. - - Given a `shift_and_log_scale_fn`, the forward and inverse transformations are - (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. - - For convenience, `masked_autoregressive_default_template` is offered as a - possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. - - Warning: no attempt is made to validate that the `shift_and_log_scale_fn` - enforces the "autoregressive property". - - Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, - - ```python - def forward(x): - y = zeros_like(x) - event_size = x.shape[-1] - for _ in range(event_size): - shift, log_scale = shift_and_log_scale_fn(y) - y = x * math_ops.exp(log_scale) + shift - return y - ``` - - and the inverse transformation is, - - ```python - def inverse(y): - shift, log_scale = shift_and_log_scale_fn(y) - return (y - shift) / math_ops.exp(log_scale) - ``` - - Notice that the `inverse` does not need a for-loop. This is because in the - forward pass each calculation of `shift` and `log_scale` is based on the `y` - calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is - equivalent to the scaling used in `forward` after `event_size` passes, i.e., - the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this - also proves the transform is bijective.) - - #### Example Use - - ```python - ds = tf.contrib.distributions - bs = tf.contrib.distributions.bijectors - - dims = 5 - - # A common choice for a normalizing flow is to use a Gaussian for the base - # distribution. (However, any continuous distribution would work.) E.g., - maf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512])), - event_shape=[dims]) - - x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. - maf.log_prob(x) # Almost free; uses Bijector caching. - maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. - - # [1] also describes an "Inverse Autoregressive Flow", e.g., - iaf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.Invert(bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512]))), - event_shape=[dims]) - - x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. - iaf.log_prob(x) # Almost free; uses Bijector caching. - iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. - - # In many (if not most) cases the default `shift_and_log_scale_fn` will be a - # poor choice. Here's an example of using a "shift only" version and with a - # different number/depth of hidden layers. - shift_only = True - maf_no_scale_hidden2 = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - bs.masked_autoregressive_default_template( - hidden_layers=[32], - shift_only=shift_only), - is_constant_jacobian=shift_only), - event_shape=[dims]) - ``` - - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 - - [2]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - """ - - def __init__(self, - shift_and_log_scale_fn, - is_constant_jacobian=False, - validate_args=False, - name=None): - """Creates the MaskedAutoregressiveFlow bijector. - - Args: - shift_and_log_scale_fn: Python `callable` which computes `shift` and - `log_scale` from both the forward domain (`x`) and the inverse domain - (`y`). Calculation must respect the "autoregressive property" (see class - docstring). Suggested default - `masked_autoregressive_default_template(hidden_layers=...)`. - Typically the function contains `tf.Variables` and is wrapped using - `tf.make_template`. Returning `None` for either (both) `shift`, - `log_scale` is equivalent to (but more efficient than) returning zero. - is_constant_jacobian: Python `bool`. Default: `False`. When `True` the - implementation assumes `log_scale` does not depend on the forward domain - (`x`) or inverse domain (`y`) values. (No validation is made; - `is_constant_jacobian=False` is always safe but possibly computationally - inefficient.) - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - name = name or "masked_autoregressive_flow" - self._shift_and_log_scale_fn = shift_and_log_scale_fn - super(MaskedAutoregressiveFlow, self).__init__( - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - - def _forward(self, x): - event_size = array_ops.shape(x)[-1] - def _loop_body(index, y0): - """While-loop body for autoregression calculation.""" - # Set caching device to avoid re-getting the tf.Variable for every while - # loop iteration. - with variable_scope_lib.variable_scope( - variable_scope_lib.get_variable_scope()) as vs: - if vs.caching_device is None: - vs.set_caching_device(lambda op: op.device) - shift, log_scale = self._shift_and_log_scale_fn(y0) - y = x - if log_scale is not None: - y *= math_ops.exp(log_scale) - if shift is not None: - y += shift - return index + 1, y - _, y = control_flow_ops.while_loop( - cond=lambda index, _: index < event_size, - body=_loop_body, - loop_vars=[0, array_ops.zeros_like(x, name="y0")]) - return y - - def _inverse(self, y): - shift, log_scale = self._shift_and_log_scale_fn(y) - x = y - if shift is not None: - x -= shift - if log_scale is not None: - x *= math_ops.exp(-log_scale) - return x - - def _inverse_log_det_jacobian(self, y): - _, log_scale = self._shift_and_log_scale_fn(y) - if log_scale is None: - return constant_op.constant(0., dtype=y.dtype, name="ildj") - return -math_ops.reduce_sum(log_scale, axis=-1) - - -MASK_INCLUSIVE = "inclusive" -MASK_EXCLUSIVE = "exclusive" - - -def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): - """Generate the slices for building an autoregressive mask.""" - # TODO(b/67594795): Better support of dynamic shape. - slices = [] - col = 0 - d_in = n_in // num_blocks - d_out = n_out // num_blocks - row = d_out if mask_type == MASK_EXCLUSIVE else 0 - for _ in range(num_blocks): - row_slice = slice(row, None) - col_slice = slice(col, col + d_in) - slices.append([row_slice, col_slice]) - col += d_in - row += d_out - return slices - - -def _gen_mask(num_blocks, - n_in, - n_out, - mask_type=MASK_EXCLUSIVE, - dtype=dtypes.float32): - """Generate the mask for building an autoregressive dense layer.""" - # TODO(b/67594795): Better support of dynamic shape. - mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) - slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) - for [row_slice, col_slice] in slices: - mask[row_slice, col_slice] = 1 - return mask - - -def masked_dense(inputs, - units, - num_blocks=None, - exclusive=False, - kernel_initializer=None, - reuse=None, - name=None, - *args, - **kwargs): - """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - inputs: Tensor input. - units: Python `int` scalar representing the dimensionality of the output - space. - num_blocks: Python `int` scalar representing the number of blocks for the - MADE masks. - exclusive: Python `bool` scalar representing whether to zero the diagonal of - the mask, used for the first layer of a MADE. - kernel_initializer: Initializer function for the weight matrix. - If `None` (default), weights are initialized using the - `tf.glorot_random_initializer`. - reuse: Python `bool` scalar representing whether to reuse the weights of a - previous layer by the same name. - name: Python `str` used to describe ops managed by this function. - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - Output tensor. - - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - # TODO(b/67594795): Better support of dynamic shape. - input_depth = inputs.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - - mask = _gen_mask(num_blocks, input_depth, units, - MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T - - if kernel_initializer is None: - kernel_initializer = init_ops.glorot_normal_initializer() - - def masked_initializer(shape, dtype=None, partition_info=None): - return mask * kernel_initializer(shape, dtype, partition_info) - - with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): - layer = layers.Dense( - units, - kernel_initializer=masked_initializer, - kernel_constraint=lambda x: mask * x, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse, - *args, - **kwargs) - return layer.apply(inputs) - - -def masked_autoregressive_default_template( - hidden_layers, - shift_only=False, - activation=nn_ops.relu, - log_scale_min_clip=-5., - log_scale_max_clip=3., - log_scale_clip_gradient=False, - name=None, - *args, - **kwargs): - """Build the MADE Model [1]. - - This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. - - Warning: This function uses `masked_dense` to create randomly initialized - `tf.Variables`. It is presumed that these will be fit, just as you would any - other neural architecture which uses `tf.layers.dense`. - - #### About Hidden Layers: - - Each element of `hidden_layers` should be greater than the `input_depth` - (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the - neural network). This is necessary to ensure the autoregressivity property. - - #### About Clipping: - - This function also optionally clips the `log_scale` (but possibly not its - gradient). This is useful because if `log_scale` is too small/large it might - underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` - bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` - `bool` indicates whether the gradient should also be clipped. The default does - not clip the gradient; this is useful because it still provides gradient - information (for fitting) yet solves the numerical stability problem. I.e., - `log_scale_clip_gradient = False` means - `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual - `grad[clip(x)] exp(clip(x))`. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - hidden_layers: Python `list`-like of non-negative integer, scalars - indicating the number of units in each hidden layer. Default: `[512, 512]. - shift_only: Python `bool` indicating if only the `shift` term shall be - computed. Default: `False`. - activation: Activation function (callable). Explicitly setting to `None` - implies a linear activation. - log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The minimum value to clip by. Default: -5. - log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The maximum value to clip by. Default: 3. - log_scale_clip_gradient: Python `bool` indicating that the gradient of - `tf.clip_by_value` should be preserved. Default: `False`. - name: A name for ops managed by this function. Default: - "masked_autoregressive_default_template". - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). - - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - - with ops.name_scope(name, "masked_autoregressive_default_template", - values=[log_scale_min_clip, log_scale_max_clip]): - def _fn(x): - """MADE parameterized via `masked_autoregressive_default_template`.""" - # TODO(b/67594795): Better support of dynamic shape. - input_depth = x.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() - else array_ops.shape(x)) - for i, units in enumerate(hidden_layers): - x = masked_dense( - inputs=x, - units=units, - num_blocks=input_depth, - exclusive=True if i == 0 else False, - activation=activation, - *args, - **kwargs) - x = masked_dense( - inputs=x, - units=(1 if shift_only else 2) * input_depth, - num_blocks=input_depth, - activation=None, - *args, - **kwargs) - if shift_only: - x = array_ops.reshape(x, shape=input_shape) - return x, None - x = array_ops.reshape( - x, shape=array_ops.concat([input_shape, [2]], axis=0)) - shift, log_scale = array_ops.unstack(x, num=2, axis=-1) - which_clip = (math_ops.clip_by_value if log_scale_clip_gradient - else _clip_by_value_preserve_grad) - log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) - return shift, log_scale - return template_ops.make_template( - "masked_autoregressive_default_template", _fn) - - -def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): - """Clips input while leaving gradient unaltered.""" - with ops.name_scope(name, "clip_by_value_preserve_grad", - [x, clip_value_min, clip_value_max]): - clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) - return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index a187ce22d686ee1203802ae2bfe64b0e1a3ea850..8654cc39d0c41ec4f1b85cd5fc4366ceaf4b224d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -12,18 +12,127 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Permute bijector.""" +"""Permutation bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.permute_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Permute"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Permute", +] + + +class Permute(bijector_lib.Bijector): + """Permutes the rightmost dimension of a `Tensor`. + + ```python + tfd = tf.contrib.distributions + + reverse = tfd.bijectors.Permute(permutation=[2, 1, 0]) + + reverse.forward([-1., 0., 1.]) + # ==> [1., 0., -1] + + reverse.inverse([1., 0., -1]) + # ==> [-1., 0., 1.] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + Warning: `tf.estimator` may repeatedly build the graph thus + `Permute(np.random.permutation(event_size)).astype("int32"))` is not a + reliable parameterization (nor would it be even if using `tf.constant`). A + safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + i.e., + + ```python + def init_once(x, name): + return tf.get_variable(name, initializer=x, trainable=False) + + Permute(permutation=init_once( + np.random.permutation(event_size).astype("int32"), + name="permutation")) + ``` + + """ + + def __init__(self, permutation, validate_args=False, name=None): + """Creates the `Permute` bijector. + + Args: + permutation: An `int`-like vector-shaped `Tensor` representing the + permutation to apply to the rightmost dimension of the transformed + `Tensor`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if `not permutation.dtype.is_integer`. + ValueError: if `permutation` does not contain exactly one of each of + `{0, 1, ..., d}`. + """ + with ops.name_scope(name, "permute", values=[permutation]): + permutation = ops.convert_to_tensor( + permutation, + name="permutation") + if not permutation.dtype.is_integer: + raise TypeError("permutation.dtype ({}) should be `int`-like.".format( + permutation.dtype.name)) + p = tensor_util.constant_value(permutation) + if p is not None: + if set(p) != set(np.arange(p.size)): + raise ValueError("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.") + elif validate_args: + p, _ = nn_ops.top_k(-permutation, + k=array_ops.shape(permutation)[-1], + sorted=True) + permutation = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + -p, math_ops.range(array_ops.size(p)), + message=("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.")), + ], permutation) + self._permutation = permutation + super(Permute, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "permute") + + @property + def permutation(self): + return self._permutation + + def _forward(self, x): + return array_ops.gather(x, self.permutation, axis=-1) + + def _inverse(self, y): + return array_ops.gather( + y, + array_ops.invert_permutation(self.permutation), + axis=-1) + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py deleted file mode 100644 index b1d8f2f41b28a88208a19824377f93882b767f03..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Permutation bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Permute", -] - - -class Permute(bijector_lib.Bijector): - """Permutes the rightmost dimension of a `Tensor`. - - ```python - bs = tf.contrib.distributions.bijectors - - reverse = bs.Permute(permutation=[2, 1, 0]) - - reverse.forward([-1., 0., 1.]) - # ==> [1., 0., -1] - - reverse.inverse([1., 0., -1]) - # ==> [-1., 0., 1.] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. - - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - Warning: `tf.estimator` may repeatedly build the graph thus - `Permute(np.random.permutation(event_size)).astype("int32"))` is not a - reliable parameterization (nor would it be even if using `tf.constant`). A - safe alternative is to use `tf.get_variable` to achieve "init once" behavior, - i.e., - - ```python - def init_once(x, name): - return tf.get_variable(name, initializer=x, trainable=False) - - Permute(permutation=init_once( - np.random.permutation(event_size).astype("int32"), - name="permutation")) - ``` - - """ - - def __init__(self, permutation, validate_args=False, name=None): - """Creates the `Permute` bijector. - - Args: - permutation: An `int`-like vector-shaped `Tensor` representing the - permutation to apply to the rightmost dimension of the transformed - `Tensor`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if `not permutation.dtype.is_integer`. - ValueError: if `permutation` does not contain exactly one of each of - `{0, 1, ..., d}`. - """ - with ops.name_scope(name, "permute", values=[permutation]): - permutation = ops.convert_to_tensor( - permutation, - name="permutation") - if not permutation.dtype.is_integer: - raise TypeError("permutation.dtype ({}) should be `int`-like.".format( - permutation.dtype.name)) - p = tensor_util.constant_value(permutation) - if p is not None: - if set(p) != set(np.arange(p.size)): - raise ValueError("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.") - elif validate_args: - p, _ = nn_ops.top_k(-permutation, - k=array_ops.shape(permutation)[-1], - sorted=True) - permutation = control_flow_ops.with_dependencies([ - check_ops.assert_equal( - -p, math_ops.range(array_ops.size(p)), - message=("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.")), - ], permutation) - self._permutation = permutation - super(Permute, self).__init__( - is_constant_jacobian=True, - validate_args=validate_args, - name=name or "permute") - - @property - def permutation(self): - return self._permutation - - def _forward(self, x): - return array_ops.gather(x, self.permutation, axis=-1) - - def _inverse(self, y): - return array_ops.gather( - y, - array_ops.invert_permutation(self.permutation), - axis=-1) - - def _inverse_log_det_jacobian(self, y): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index a83199549cd16101ab7b39b43d19a17bc66f03df..c37db61720d10949f294ff7b2e9778ba6efa57f0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -18,12 +18,110 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.power_transform_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["PowerTransform"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "PowerTransform", +] + + +class PowerTransform(bijector.Bijector): + """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. + + The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps + inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` + of this bijector. + + This bijector is equivalent to the `Exp` bijector when `c=0`. + """ + + def __init__(self, + power=0., + event_ndims=0, + validate_args=False, + name="power_transform"): + """Instantiates the `PowerTransform` bijector. + + Args: + power: Python `float` scalar indicating the transform power, i.e., + `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `power < 0` or is not known statically. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[power]): + power = tensor_util.constant_value( + ops.convert_to_tensor(power, name="power")) + if power is None or power < 0: + raise ValueError("`power` must be a non-negative TF constant.") + self._power = power + super(PowerTransform, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def power(self): + """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" + return self._power + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + if self.power == 0.: + return math_ops.exp(x) + # If large x accuracy is an issue, consider using: + # (1. + x * self.power)**(1. / self.power) when x >> 1. + return math_ops.exp(math_ops.log1p(x * self.power) / self.power) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + if self.power == 0.: + return math_ops.log(y) + # If large y accuracy is an issue, consider using: + # (y**self.power - 1.) / self.power when y >> 1. + return math_ops.expm1(math_ops.log(y) * self.power) / self.power + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return (self.power - 1.) * math_ops.reduce_sum( + math_ops.log(y), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + if self.power == 0.: + return math_ops.reduce_sum(x, axis=event_dims) + return (1. / self.power - 1.) * math_ops.reduce_sum( + math_ops.log1p(x * self.power), + axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args or self.power == 0.: + return x + is_valid = check_ops.assert_non_negative( + 1. + self.power * x, + message="Forward transformation input must be at least {}.".format( + -1. / self.power)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_valid = check_ops.assert_positive( + y, message="Inverse transformation input must be greater than 0.") + return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py deleted file mode 100644 index c37db61720d10949f294ff7b2e9778ba6efa57f0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""PowerTransform bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "PowerTransform", -] - - -class PowerTransform(bijector.Bijector): - """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. - - The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps - inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` - of this bijector. - - This bijector is equivalent to the `Exp` bijector when `c=0`. - """ - - def __init__(self, - power=0., - event_ndims=0, - validate_args=False, - name="power_transform"): - """Instantiates the `PowerTransform` bijector. - - Args: - power: Python `float` scalar indicating the transform power, i.e., - `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `power < 0` or is not known statically. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[power]): - power = tensor_util.constant_value( - ops.convert_to_tensor(power, name="power")) - if power is None or power < 0: - raise ValueError("`power` must be a non-negative TF constant.") - self._power = power - super(PowerTransform, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def power(self): - """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" - return self._power - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - if self.power == 0.: - return math_ops.exp(x) - # If large x accuracy is an issue, consider using: - # (1. + x * self.power)**(1. / self.power) when x >> 1. - return math_ops.exp(math_ops.log1p(x * self.power) / self.power) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - if self.power == 0.: - return math_ops.log(y) - # If large y accuracy is an issue, consider using: - # (y**self.power - 1.) / self.power when y >> 1. - return math_ops.expm1(math_ops.log(y) * self.power) / self.power - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return (self.power - 1.) * math_ops.reduce_sum( - math_ops.log(y), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - if self.power == 0.: - return math_ops.reduce_sum(x, axis=event_dims) - return (1. / self.power - 1.) * math_ops.reduce_sum( - math_ops.log1p(x * self.power), - axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args or self.power == 0.: - return x - is_valid = check_ops.assert_non_negative( - 1. + self.power * x, - message="Forward transformation input must be at least {}.".format( - -1. / self.power)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_valid = check_ops.assert_positive( - y, message="Inverse transformation input must be greater than 0.") - return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 8997f7ab6929745275edb38712a5bbb0a9b25ddb..55eca063126797d577653f0d6bcdfddf8192bdb5 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -12,18 +12,303 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Reshape bijector.""" +"""Reshape bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Reshape"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Reshape", +] + + +def _static_ndims_from_shape(shape): + return shape.shape.with_rank_at_least(1)[0].value + + +def _ndims_from_shape(shape): + return array_ops.shape(shape)[0] + + +class Reshape(bijector_lib.Bijector): + """Reshapes the `event_shape` of a `Tensor`. + + The semantics generally follow that of `tf.reshape()`, with + a few differences: + + * The user must provide both the input and output shape, so that + the transformation can be inverted. If an input shape is not + specified, the default assumes a vector-shaped input, i.e., + event_shape_in = (-1,). + * The `Reshape` bijector automatically broadcasts over the leftmost + dimensions of its input (`sample_shape` and `batch_shape`); only + the rightmost `event_ndims_in` dimensions are reshaped. The + number of dimensions to reshape is inferred from the provided + `event_shape_in` (`event_ndims_in = len(event_shape_in)`). + + Example usage: + ```python + + tfd = tf.contrib.distributions + + r = tfd.bijectors.Reshape(event_shape_out=[1, -1]) + + r.forward([3., 4.]) # shape [2] + # ==> [[3., 4.]] # shape [1, 2] + + r.forward([[1., 2.], [3., 4.]]) # shape [2, 2] + # ==> [[[1., 2.]], + # [[3., 4.]]] # shape [2, 1, 2] + + r.inverse([[3., 4.]]) # shape [1,2] + # ==> [3., 4.] # shape [2] + + r.forward_log_det_jacobian(any_value) + # ==> 0. + + r.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + """ + + def __init__(self, event_shape_out, event_shape_in=(-1,), + validate_args=False, name=None): + """Creates a `Reshape` bijector. + + Args: + event_shape_out: An `int`-like vector-shaped `Tensor` + representing the event shape of the transformed output. + event_shape_in: An optional `int`-like vector-shape `Tensor` + representing the event shape of the input. This is required in + order to define inverse operations; the default of (-1,) + assumes a vector-shaped input. + validate_args: Python `bool` indicating whether arguments should + be checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if either `event_shape_in` or `event_shape_out` has + non-integer `dtype`. + ValueError: if either of `event_shape_in` or `event_shape_out` + has non-vector shape (`rank > 1`), or if their sizes do not + match. + """ + with ops.name_scope(name, "reshape", + values=[event_shape_out, event_shape_in]): + + event_shape_out = ops.convert_to_tensor(event_shape_out, + name="event_shape_out", + preferred_dtype=dtypes.int32) + event_shape_in = ops.convert_to_tensor(event_shape_in, + name="event_shape_in", + preferred_dtype=dtypes.int32) + + assertions = [] + assertions.extend(self._maybe_check_valid_shape( + event_shape_out, validate_args)) + assertions.extend(self._maybe_check_valid_shape( + event_shape_in, validate_args)) + + self._assertions = assertions + self._event_shape_in = event_shape_in + self._event_shape_out = event_shape_out + + super(Reshape, self).__init__(is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") + + def _maybe_check_valid_shape(self, shape, validate_args): + """Check that a shape Tensor is int-type and otherwise sane.""" + if not shape.dtype.is_integer: + raise TypeError("{} dtype ({}) should be `int`-like.".format( + shape.op.name, shape.dtype.name)) + + assertions = [] + + ndims = array_ops.rank(shape) + ndims_ = tensor_util.constant_value(ndims) + if ndims_ is not None and ndims_ > 1: + raise ValueError("`{}` rank ({}) should be <= 1.".format( + shape.op.name, ndims_)) + elif validate_args: + assertions.append(check_ops.assert_less_equal( + ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name))) + + shape_ = tensor_util.constant_value_as_shape(shape) + if shape_.is_fully_defined(): + es = np.int32(shape_.as_list()) + if sum(es == -1) > 1: + raise ValueError( + "`{}` must have at most one `-1` (given {})" + .format(shape.op.name, es)) + if np.any(es < -1): + raise ValueError( + "`{}` elements must be either positive integers or `-1`" + "(given {})." + .format(shape.op.name, es)) + elif validate_args: + assertions.extend([ + check_ops.assert_less_equal( + math_ops.reduce_sum( + math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)), + 1, + message="`{}` elements must have at most one `-1`." + .format(shape.op.name)), + check_ops.assert_greater_equal( + shape, -1, + message="`{}` elements must be either positive integers or `-1`." + .format(shape.op.name)), + ]) + return assertions + + def _reshape_helper(self, x, event_shape_in, event_shape_out): + """Reshape only the event_shape of an input `Tensor`.""" + + event_ndims_in_ = _static_ndims_from_shape(event_shape_in) + event_ndims_in = _ndims_from_shape(event_shape_in) + x_ndims_, x_ndims = x.shape.ndims, array_ops.rank(x) + + assertions = [] + + # Ensure x.event_shape is compatible with event_shape_in. + if (event_ndims_in_ is not None + and x_ndims_ is not None + and x.shape.with_rank_at_least(event_ndims_in_)[ + x_ndims_-event_ndims_in_:].is_fully_defined()): + x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking + np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 + else: + x_event_shape_, x_event_shape = ( + None, array_ops.shape(x)[x_ndims-event_ndims_in:]) + + event_shape_in_ = tensor_util.constant_value(event_shape_in) + + if x_event_shape_ is not None and event_shape_in_ is not None: + # Compare the shape dimensions that are fully specified in the + # input (i.e., for which event_shape_in is not -1). If x_event_shape + # matches along all of these dimensions, it is compatible with + # the desired input shape and any further mismatches (i.e., + # imcompatibility with the desired *output* shape) will be + # caught inside of array_ops.reshape() below. + x_event_shape_specified_ = x_event_shape_[event_shape_in_ >= 0] + event_shape_in_specified_ = event_shape_in_[event_shape_in_ >= 0] + if not np.equal(x_event_shape_specified_, + event_shape_in_specified_).all(): + raise ValueError( + "Input `event_shape` does not match `event_shape_in` ({} vs {}).". + format(x_event_shape_, event_shape_in_)) + elif self.validate_args: + # Similarly to the static case, we compare the shape dimensions + # that are fully specified in the input. We extract these + # dimensions using boolean_mask(), which requires that the mask + # have known ndims. We can assume that shape Tensors always have + # ndims==1 (this assumption is verified inside of + # _maybe_check_valid_shape), so the reshape operation is just a + # no-op that formally encodes this fact to make boolean_mask() + # happy. + event_shape_mask = array_ops.reshape(event_shape_in >= 0, [-1]) + x_event_shape_specified = array_ops.boolean_mask(x_event_shape, + event_shape_mask) + event_shape_in_specified = array_ops.boolean_mask(event_shape_in, + event_shape_mask) + assertions.append(check_ops.assert_equal( + x_event_shape_specified, event_shape_in_specified, + message="Input `event_shape` does not match `event_shape_in`.")) + + if assertions: + x = control_flow_ops.with_dependencies(assertions, x) + + # get the parts of shape(x) that will not change + sample_and_batch_shape = array_ops.shape(x) + + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x)) + sample_and_batch_shape = sample_and_batch_shape[ + :(ndims - math_ops.abs(event_ndims_in))] + + if (event_ndims_in_ is not None + and x_ndims_ is not None + and event_ndims_in_ == x_ndims_): + # Hack to allow forward/inverse_event_shape to do shape + # inference by calling this helper method with a dummy Tensor of + # shape event_shape_in. In this special case, + # sample_and_batch_shape will be empty so we can preserve static + # shape information by avoiding the concat operation below + # (which would be a no-op). + new_shape = event_shape_out + else: + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) + + return array_ops.reshape(x, new_shape) + + def _forward(self, x): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(x, + self._event_shape_in, + self._event_shape_out) + + def _inverse(self, y): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(y, + self._event_shape_out, + self._event_shape_in) + + def _inverse_log_det_jacobian(self, y): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=x.dtype) + + def _forward_event_shape(self, input_shape): + # NOTE: this method and the other *_event_shape* methods + # compute shape by explicit transformation of a dummy + # variable. This approach is not generally recommended because it + # bloats the graph and could in general trigger side effects. + # + # In this particular case of the Reshape bijector, the + # forward and inverse transforms have no side effects, and we + # believe the reduction in code complexity from delegating the + # heavy lifting to tf.reshape() is worth the added graph ops. + # However, you should think hard before implementing this approach + # in other Bijectors; it is strongly preferred to compute + # shapes explicitly whenever it's feasible to do so. + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return dummy_reshaped.shape + + def _inverse_event_shape(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return dummy_reshaped.shape + + def _forward_event_shape_tensor(self, input_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return array_ops.shape(dummy_reshaped) + + def _inverse_event_shape_tensor(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return array_ops.shape(dummy_reshaped) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py deleted file mode 100644 index 93682639aa3be3b8f59a369dedb6ee773c468130..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Reshape bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Reshape", -] - - -class Reshape(bijector_lib.Bijector): - """Reshapes the `event_shape` of a `Tensor`. - - The semantics generally follow that of `tf.reshape()`, with - a few differences: - * The user must provide both the input and output shape, so that - the transformation can be inverted. - * The `Reshape` bijector automatically broadcasts over the leftmost - dimensions of its input (`sample_shape` and `batch_shape`); only - the rightmost `event_ndims_in` dimensions are reshaped. The - number of dimensions to reshape is inferred from the provided - `event_shape_in` (`event_ndims_in = len(event_shape_in)`). - * The `Reshape` bijector does not currently support - partially-specified shapes, i.e., those with a dimension - implicitly specified by `-1`. - - Example usage: - ```python - - bs = tf.contrib.distributions.bijectors - - reverse = bs.Reshape(event_shape_out=[1,2], - event_shape_in=[2,]) - - reverse.forward([1., 2.]) # shape [2,] - # ==> [[1., 2.]] # shape [1,2] - - reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2] - # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2] - - reverse.inverse([[1., 2.]]) # shape [1,2] - # ==> [1., 2.] # shape [2,] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. - - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - """ - - def __init__(self, event_shape_out, event_shape_in, - validate_args=False, name=None): - """Creates a `Reshape` bijector. - - Args: - event_shape_out: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - transformed output. - event_shape_in: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - input. - validate_args: Python `bool` indicating whether arguments should - be checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if either `event_shape_in` or `event_shape_out` has - non-vector shape (`rank > 1`), or non-integer `dtype`. - ValueError: if either `event_shape_in` or `event_shape_out` - contains non-positive entries, or if their sizes do not match - (`prod(event_shape_in)` != `prod(event_shape_out)`), or if - their dimensionality(s) cannot be statically inferred. - """ - with ops.name_scope(name, "reshape", - values=[event_shape_out, event_shape_in]): - - event_shape_out = ops.convert_to_tensor(event_shape_out, - name="event_shape_out", - preferred_dtype=dtypes.int32) - event_shape_in = ops.convert_to_tensor(event_shape_in, - name="event_shape_in", - preferred_dtype=dtypes.int32) - - # check that input shapes are positive integers - assertions = [] - assertions += self._maybe_check_valid_shape( - event_shape_out, "event_shape_out", - validate_args=validate_args) - assertions += self._maybe_check_valid_shape( - event_shape_in, "event_shape_in", validate_args=validate_args) - - # check that prod(event_shape_in) = prod(event_shape_out) - assertions += self._maybe_check_matching_sizes( - event_shape_in, event_shape_out, validate_args=validate_args) - - self._assertions = assertions - self._event_shape_in = event_shape_in - self._event_shape_out = event_shape_out - self._event_shape_in_static = tensor_util.constant_value_as_shape( - event_shape_in) - self._event_shape_out_static = tensor_util.constant_value_as_shape( - event_shape_out) - - super(Reshape, self).__init__(is_constant_jacobian=True, - validate_args=validate_args, - name=name or "reshape") - - def _maybe_check_valid_shape(self, shape_tensor, label, - validate_args=False): - """Check that a shape Tensor is int-type and positive.""" - - assertions = [] - - if not shape_tensor.dtype.is_integer: - raise TypeError("{} dtype ({}) should be `int`-like.".format( - label, shape_tensor.dtype.name)) - - shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor)) - if shape_rank is not None and shape_rank > 1: - raise ValueError("{} rank should be <= 1.".format(label)) - - s = tensor_util.constant_value(shape_tensor) - if s is not None: - if (s <= 0).any(): - raise ValueError("{} entries must be positive, but found {}".format( - label, s)) - elif validate_args: - assertions.append(check_ops.assert_positive( - shape_tensor, message="{} entries must be positive".format(label))) - - return assertions - - def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out, - validate_args=False): - """Check that prod(event_shape_in)==prod(event_shape_out).""" - - def _get_size_from_shape(shape): - """Computes size from a shape `Tensor`, statically if possible.""" - s = tensor_util.constant_value(shape) - if s is not None: - return [np.int32(np.prod(s))]*2 - return None, math_ops.reduce_prod(shape, name="size") - - # Ensure `event_shape_in` is compatible with `event_shape_out`. - event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_in) - event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_out) - - assertions = [] - if event_size_in_ is not None and event_size_out_ is not None: - if event_size_in_ != event_size_out_: - raise ValueError( - "Input `event_size` ({}) does not match output `event_size` ({}).". - format(event_size_in, event_size_out_)) - elif validate_args: - assertions.append(check_ops.assert_equal( - event_size_in, event_size_out, - message="Input/output `event_size`s do not match.")) - - return assertions - - def _reshape_helper(self, x, event_shape_in, event_shape_out): - """Reshape only the event_shape of an input `Tensor`.""" - - def _get_rank_from_shape(shape): - """Computes rank from a shape `Tensor`, statically if possible.""" - # Uses fact that rank is "shape of shape". - ndims = shape.shape.with_rank_at_least(1)[0].value - if ndims is not None: - return ndims, ndims - return None, array_ops.shape(shape)[0] - - event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) - - assertions = [] - # Ensure x.event_shape is compatible with event_shape_in. - if x.shape.ndims is not None: - x_ndims_, x_ndims = [x.shape.ndims]*2 - else: - x_ndims_, x_ndims = None, array_ops.rank(x) - - if (event_ndims_in_ is not None - and x_ndims_ is not None - and x.shape.with_rank_at_least(event_ndims_in_)[ - x_ndims_-event_ndims_in_:].is_fully_defined()): - x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking - np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 - else: - x_event_shape_, x_event_shape = ( - None, array_ops.shape(x)[x_ndims-event_ndims_in:]) - - event_shape_in_ = tensor_util.constant_value(event_shape_in) - - if x_event_shape_ is not None and event_shape_in_ is not None: - if not np.equal(x_event_shape_, event_shape_in_).all(): - raise ValueError( - "Input `event_shape` ({}) does not match `event_shape_in` ({}).". - format(x_event_shape_, event_shape_in_)) - elif self.validate_args: - assertions.append(check_ops.assert_equal( - x_event_shape, event_shape_in, - message="Input `event_shape` does not match `event_shape_in`.")) - - if assertions: - x = control_flow_ops.with_dependencies(assertions, x) - - # get the parts of shape(x) that will not change - sample_and_batch_shape = array_ops.shape(x) - - ndims = (x.shape.ndims if x.shape.ndims is not None - else array_ops.rank(x)) - sample_and_batch_shape = sample_and_batch_shape[ - :(ndims - math_ops.abs(event_ndims_in))] - - new_shape = array_ops.concat( - [sample_and_batch_shape, event_shape_out], axis=0) - - return array_ops.reshape(x, new_shape) - - def _forward(self, x): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(x, - self._event_shape_in, - self._event_shape_out) - - def _inverse(self, y): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(y, - self._event_shape_out, - self._event_shape_in) - - def _inverse_log_det_jacobian(self, y): - with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=x.dtype) - - def _forward_event_shape(self, input_shape): - self._event_shape_in_static.assert_is_compatible_with(input_shape) - return self._event_shape_out_static - - def _inverse_event_shape(self, output_shape): - self._event_shape_out_static.assert_is_compatible_with(output_shape) - return self._event_shape_in_static - - def _forward_event_shape_tensor(self, input_shape): - input_assertions = self._maybe_check_valid_shape( - input_shape, "input event shape", validate_args=self.validate_args) - input_assertions += self._maybe_check_matching_sizes( - input_shape, self._event_shape_out, - validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - input_assertions + self._assertions, self._event_shape_out) - - def _inverse_event_shape_tensor(self, output_shape): - - output_assertions = self._maybe_check_valid_shape( - output_shape, "output event shape", validate_args=self.validate_args) - output_assertions += self._maybe_check_matching_sizes( - output_shape, self._event_shape_in, validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - output_assertions + self._assertions, self._event_shape_in) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index c20e76c0b7367369865faf973377201c8b8b17e6..a640dfe7dfbcce96261589c7fc49107deaefdd54 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -18,12 +18,31 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Sigmoid"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Sigmoid", +] + + +class Sigmoid(bijector.Bijector): + """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + + def __init__(self, validate_args=False, name="sigmoid"): + super(Sigmoid, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) + + def _forward(self, x): + return math_ops.sigmoid(x) + + def _inverse(self, y): + return math_ops.log(y) - math_ops.log1p(-y) + + def _inverse_log_det_jacobian(self, y): + return -math_ops.log(y) - math_ops.log1p(-y) + + def _forward_log_det_jacobian(self, x): + return -nn_ops.softplus(-x) - nn_ops.softplus(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py index 448125230d24066697624bce03fed71a2c2f00b1..223bc9d042c69be05b0e578835a31ed6e83c0c97 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py @@ -18,12 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered -_allowed_symbols = ["SigmoidCentered"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SigmoidCentered", +] + + +class SigmoidCentered(softmax_centered.SoftmaxCentered): + """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). + + Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. + + See `bijector.SoftmaxCentered` for more details. + """ + + def __init__(self, validate_args=False, name="sigmoid_centered"): + super(SigmoidCentered, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index b3cf03c24612f5c618c71c0a8615f272acdf2d10..3a75e4ae9495793901b0da91a5aa3982aab35852 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -18,12 +18,162 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SinhArcsinh"] +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SinhArcsinh", +] + + +def _sqrtx2p1(x): + """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" + return array_ops.where( + math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., + math_ops.sqrt(x**2. + 1.), + # For large x, calculating x**2 can overflow. This can be alleviated by + # considering: + # sqrt(1 + x**2) + # = exp(0.5 log(1 + x**2)) + # = exp(0.5 log(x**2 * (1 + x**-2))) + # = exp(log(x) + 0.5 * log(1 + x**-2)) + # = |x| * exp(0.5 log(1 + x**-2)) + # = |x| * sqrt(1 + x**-2) + # We omit the last term in this approximation. + # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, + # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, + # and higher order gradients, since the first order derivative of + # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x), + # and all nth-order derivatives will be O(x**-(n + 2)). This makes any + # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. + math_ops.abs(x)) + + +class SinhArcsinh(bijector.Bijector): + """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. + + For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this + transformation is a + diffeomorphism of the real line `(-inf, inf)`. The inverse transform is + `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. + + The `SinhArcsinh` transformation of the Normal is described in + [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) + This Bijector allows a similar transformation of any distribution supported on + `(-inf, inf)`. + + #### Meaning of the parameters + + * If `skewness = 0` and `tailweight = 1`, this transform is the identity. + * Positive (negative) `skewness` leads to positive (negative) skew. + * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is + "tilted" to the right. + * positive skew means positive values of `Y` become more likely, and + negative values become less likely. + * Larger (smaller) `tailweight` leads to fatter (thinner) tails. + * Fatter tails mean larger values of `|Y|` become more likely. + * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is + "flat" around `Y = 0`, and a very steep drop-off in the tails. + * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more + peaked at the mode with heavier tails. + + To see the argument about the tails, note that for `|X| >> 1` and + `|X| >> (|skewness| * tailweight)**tailweight`, we have + `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. + """ + + def __init__(self, + skewness=None, + tailweight=None, + event_ndims=0, + validate_args=False, + name="SinhArcsinh"): + """Instantiates the `SinhArcsinh` bijector. + + Args: + skewness: Skewness parameter. Float-type `Tensor`. Default is `0` + of type `float32`. + tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as + `skewness` and broadcastable `shape`. Default is `1` of type `float32`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[skewness, tailweight]): + tailweight = 1. if tailweight is None else tailweight + skewness = 0. if skewness is None else skewness + self._skewness = ops.convert_to_tensor( + skewness, name="skewness") + self._tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=self._skewness.dtype) + check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) + if validate_args: + self._tailweight = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._tailweight, + message="Argument tailweight was not positive") + ], self._tailweight) + super(SinhArcsinh, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def skewness(self): + """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._skewness + + @property + def tailweight(self): + """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._tailweight + + def _forward(self, x): + return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) + + def _inverse(self, y): + return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) + + def _inverse_log_det_jacobian(self, y): + # x = sinh(arcsinh(y) / tailweight - skewness) + # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), + # dx/dy + # = cosh(arcsinh(y) / tailweight - skewness) + # / (tailweight * sqrt(y**2 + 1)) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + math_ops.log(math_ops.cosh( + math_ops.asinh(y) / self.tailweight - self.skewness) + # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases + # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). + / _sqrtx2p1(y)) + - math_ops.log(self.tailweight), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + # y = sinh((arcsinh(x) + skewness) * tailweight) + # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), + # dy/dx + # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + math_ops.log(math_ops.cosh( + (math_ops.asinh(x) + self.skewness) * self.tailweight) + # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases + # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). + / _sqrtx2p1(x)) + + math_ops.log(self.tailweight), + axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py deleted file mode 100644 index 3a75e4ae9495793901b0da91a5aa3982aab35852..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SinhArcsinh bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "SinhArcsinh", -] - - -def _sqrtx2p1(x): - """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" - return array_ops.where( - math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., - math_ops.sqrt(x**2. + 1.), - # For large x, calculating x**2 can overflow. This can be alleviated by - # considering: - # sqrt(1 + x**2) - # = exp(0.5 log(1 + x**2)) - # = exp(0.5 log(x**2 * (1 + x**-2))) - # = exp(log(x) + 0.5 * log(1 + x**-2)) - # = |x| * exp(0.5 log(1 + x**-2)) - # = |x| * sqrt(1 + x**-2) - # We omit the last term in this approximation. - # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, - # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, - # and higher order gradients, since the first order derivative of - # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x), - # and all nth-order derivatives will be O(x**-(n + 2)). This makes any - # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. - math_ops.abs(x)) - - -class SinhArcsinh(bijector.Bijector): - """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. - - For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this - transformation is a - diffeomorphism of the real line `(-inf, inf)`. The inverse transform is - `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. - - The `SinhArcsinh` transformation of the Normal is described in - [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) - This Bijector allows a similar transformation of any distribution supported on - `(-inf, inf)`. - - #### Meaning of the parameters - - * If `skewness = 0` and `tailweight = 1`, this transform is the identity. - * Positive (negative) `skewness` leads to positive (negative) skew. - * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is - "tilted" to the right. - * positive skew means positive values of `Y` become more likely, and - negative values become less likely. - * Larger (smaller) `tailweight` leads to fatter (thinner) tails. - * Fatter tails mean larger values of `|Y|` become more likely. - * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is - "flat" around `Y = 0`, and a very steep drop-off in the tails. - * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more - peaked at the mode with heavier tails. - - To see the argument about the tails, note that for `|X| >> 1` and - `|X| >> (|skewness| * tailweight)**tailweight`, we have - `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. - """ - - def __init__(self, - skewness=None, - tailweight=None, - event_ndims=0, - validate_args=False, - name="SinhArcsinh"): - """Instantiates the `SinhArcsinh` bijector. - - Args: - skewness: Skewness parameter. Float-type `Tensor`. Default is `0` - of type `float32`. - tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as - `skewness` and broadcastable `shape`. Default is `1` of type `float32`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[skewness, tailweight]): - tailweight = 1. if tailweight is None else tailweight - skewness = 0. if skewness is None else skewness - self._skewness = ops.convert_to_tensor( - skewness, name="skewness") - self._tailweight = ops.convert_to_tensor( - tailweight, name="tailweight", dtype=self._skewness.dtype) - check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) - if validate_args: - self._tailweight = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._tailweight, - message="Argument tailweight was not positive") - ], self._tailweight) - super(SinhArcsinh, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def skewness(self): - """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._skewness - - @property - def tailweight(self): - """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._tailweight - - def _forward(self, x): - return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) - - def _inverse(self, y): - return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) - - def _inverse_log_det_jacobian(self, y): - # x = sinh(arcsinh(y) / tailweight - skewness) - # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), - # dx/dy - # = cosh(arcsinh(y) / tailweight - skewness) - # / (tailweight * sqrt(y**2 + 1)) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - math_ops.asinh(y) / self.tailweight - self.skewness) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). - / _sqrtx2p1(y)) - - math_ops.log(self.tailweight), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - # y = sinh((arcsinh(x) + skewness) * tailweight) - # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), - # dy/dx - # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - (math_ops.asinh(x) + self.skewness) * self.tailweight) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). - / _sqrtx2p1(x)) - + math_ops.log(self.tailweight), - axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index be6608f97880ae68e10b17c815bf2d8438293261..a9dcce6c526600f3b26c6bceb730417000917ce7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -18,12 +18,223 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SoftmaxCentered"] +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "SoftmaxCentered", +] + + +class SoftmaxCentered(bijector.Bijector): + """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. + + To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a + bijection, the forward transformation appends a value to the input and the + inverse removes this coordinate. The appended coordinate represents a pivot, + e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last + coordinate. + + Because we append a coordinate, this bijector only supports `event_ndim in [0, + 1]`, i.e., scalars and vectors. + + Example Use: + + ```python + bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + # Result: [0.2, 0.3, 0.4, 0.1] + # Extra result: 0.1 + + bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + # Result: tf.log([2, 3, 4]) + # Extra coordinate removed. + ``` + + At first blush it may seem like the [Invariance of domain]( + https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this + implementation is not a bijection. However, the appended dimension + makes the (forward) image non-open and the theorem does not directly apply. + """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="softmax_centered"): + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 1]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") + self._static_event_ndims = event_ndims + super(SoftmaxCentered, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward_event_shape(self, input_shape): + if input_shape.ndims is None: + return input_shape + if input_shape.ndims != self._static_event_ndims: + raise ValueError("input_shape.dims = %d != %d" % + (input_shape.ndims, self._static_event_ndims)) + if input_shape.ndims == 0: + return tensor_shape.TensorShape([2]) + if input_shape.ndims == 1: + return tensor_shape.TensorShape(input_shape[0] + 1) + # Unreachable code: + raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + + def _forward_event_shape_tensor(self, input_shape): + ndims = array_ops.shape(input_shape) + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. + is_zero_or_one = check_ops.assert_equal( + ndims, 0 if self._static_event_ndims == 0 else 1, + message="event_ndims must be 0 or 1") + ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor( + [2], dtype=dtypes.int32, name="output_shape") + return input_shape + 1 + + def _inverse_event_shape(self, output_shape): + if output_shape.ndims is None: + return output_shape + if output_shape.ndims != 1: + raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) + if self._static_event_ndims == 0: + return tensor_shape.TensorShape([]) + return tensor_shape.TensorShape(output_shape[0] - 1) + + def _inverse_event_shape_tensor(self, output_shape): + ndims = array_ops.shape(output_shape)[0] + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. + is_one = check_ops.assert_equal( + ndims, 1, message="event_ndims must be 1") + ndims = control_flow_ops.with_dependencies([is_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") + return array_ops.expand_dims(output_shape[0] - 1, dim=0) + + def _forward(self, x): + # Pad the last dim with a zeros vector. We need this because it lets us + # infer the scale in the inverse function. + y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x + y = distribution_util.pad(y, axis=-1, back=True) + + # Set shape hints. + if x.shape.ndims is not None: + shape = x.shape.as_list() + if self._static_event_ndims == 0: + shape += [2] + elif shape[-1] is not None: + shape[-1] += 1 + shape = tensor_shape.TensorShape(shape) + y.shape.assert_is_compatible_with(shape) + y.set_shape(shape) + + # Since we only support event_ndims in [0, 1] and we do padding, we always + # reduce over the last dimension, i.e., dim=-1 (which is the default). + return nn_ops.softmax(y) + + def _inverse(self, y): + # To derive the inverse mapping note that: + # y[i] = exp(x[i]) / normalization + # and + # y[end] = 1 / normalization. + # Thus: + # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) + # = log(exp(x[i])/normalization) - log(y[end]) + # = log(y[i]) - log(y[end]) + shape = (np.asarray(y.shape.as_list(), dtype=np.int32) + if y.shape.is_fully_defined() + else array_ops.shape(y, name="shape")) + ndims = distribution_util.prefer_static_rank(y) + + # Do this first to make sure CSE catches that it'll happen again in + # _inverse_log_det_jacobian. + x = math_ops.log(y) + + # We now extract the last coordinate of the rightmost dimension. + # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. + begin = array_ops.one_hot(indices=ndims-1, + depth=ndims, + on_value=shape[-1]-np.array(1, dtype=shape.dtype), + dtype=shape.dtype) + size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) + log_normalization = -array_ops.strided_slice(x, begin, begin + size) + + # Here we slice out all but the last coordinate; see above for idea. + begin = array_ops.zeros_like(shape) + size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) + x = array_ops.strided_slice(x, begin, begin + size) + + x += log_normalization + + if self._static_event_ndims == 0: + x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) + + # Set shape hints. + if y.shape.ndims is not None: + shape = y.shape.as_list() + if self._static_event_ndims == 0: + shape = shape[:-1] + elif shape[-1] is not None: + shape[-1] -= 1 + shape = tensor_shape.TensorShape(shape) + x.shape.assert_is_compatible_with(shape) + x.set_shape(shape) + + return x + + def _inverse_log_det_jacobian(self, y): + # WLOG, consider the vector case: + # x = log(y[:-1]) - log(y[-1]) + # where, + # y[-1] = 1 - sum(y[:-1]). + # We have: + # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } + # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) + # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } + # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * + # det(diag(y[:-1])) } (2) + # = 1 / { y[-1] prod(y[:-1]) } + # = 1 / prod(y) + # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula + # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector + # docstring "Tip". + # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma + return -math_ops.reduce_sum(math_ops.log(y), axis=-1) + + def _forward_log_det_jacobian(self, x): + if self._static_event_ndims == 0: + return x - 2. * nn_ops.softplus(x) + else: + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + fldj = (-log_normalization + + math_ops.reduce_sum(x - log_normalization, + axis=-1, + keep_dims=True)) + return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py deleted file mode 100644 index 8645cc1b6b04be75a419342591272f07a4a1711c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SoftmaxCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "SoftmaxCentered", -] - - -class SoftmaxCentered(bijector.Bijector): - """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. - - To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a - bijection, the forward transformation appends a value to the input and the - inverse removes this coordinate. The appended coordinate represents a pivot, - e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last - coordinate. - - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. - - Example Use: - - ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) - # Result: [0.2, 0.3, 0.4, 0.1] - # Extra result: 0.1 - - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) - # Result: tf.log([2, 3, 4]) - # Extra coordinate removed. - ``` - - At first blush it may seem like the [Invariance of domain]( - https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this - implementation is not a bijection. However, the appended dimension - makes the (forward) image non-open and the theorem does not directly apply. - """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="softmax_centered"): - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims - super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: - return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) - - def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 - - def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: - return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) - - def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) - - def _forward(self, x): - # Pad the last dim with a zeros vector. We need this because it lets us - # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - ndims = (y.get_shape().ndims if y.get_shape().ndims is not None - else array_ops.rank(y)) - y = array_ops.pad(y, - paddings=array_ops.concat( - (array_ops.zeros( - (ndims - 1, 2), dtype=dtypes.int32), [[0, 1]]), - 0)) - - # Set shape hints. - if x.get_shape().ndims is not None: - shape = x.get_shape().as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) - y.get_shape().assert_is_compatible_with(shape) - y.set_shape(shape) - - # Since we only support event_ndims in [0, 1] and we do padding, we always - # reduce over the last dimension, i.e., dim=-1 (which is the default). - return nn_ops.softmax(y) - - def _inverse(self, y): - # To derive the inverse mapping note that: - # y[i] = exp(x[i]) / normalization - # and - # y[end] = 1 / normalization. - # Thus: - # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) - # = log(exp(x[i])/normalization) - log(y[end]) - # = log(y[i]) - log(y[end]) - shape = (np.asarray(y.get_shape().as_list(), dtype=np.int32) - if y.get_shape().is_fully_defined() - else array_ops.shape(y, name="shape")) - ndims = y.get_shape().ndims or math_ops.rank(y, name="ndims") - - # Do this first to make sure CSE catches that it'll happen again in - # _inverse_log_det_jacobian. - x = math_ops.log(y) - - # We now extract the last coordinate of the rightmost dimension. - # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. - begin = array_ops.one_hot(indices=ndims-1, - depth=ndims, - on_value=shape[-1]-np.array(1, dtype=shape.dtype), - dtype=shape.dtype) - size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) - log_normalization = -array_ops.strided_slice(x, begin, begin + size) - - # Here we slice out all but the last coordinate; see above for idea. - begin = array_ops.zeros_like(shape) - size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) - x = array_ops.strided_slice(x, begin, begin + size) - - x += log_normalization - - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) - - # Set shape hints. - if y.get_shape().ndims is not None: - shape = y.get_shape().as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) - x.get_shape().assert_is_compatible_with(shape) - x.set_shape(shape) - - return x - - def _inverse_log_det_jacobian(self, y): - # WLOG, consider the vector case: - # x = log(y[:-1]) - log(y[-1]) - # where, - # y[-1] = 1 - sum(y[:-1]). - # We have: - # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } - # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) - # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } - # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * - # det(diag(y[:-1])) } (2) - # = 1 / { y[-1] prod(y[:-1]) } - # = 1 / prod(y) - # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula - # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector - # docstring "Tip". - # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma - return -math_ops.reduce_sum(math_ops.log(y), axis=-1) - - def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 250a1144b53bb43271ff7ee494604d9bae6feda8..81957fcf78922fa15fd20a25d144071f431161ae 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -18,12 +18,127 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softplus_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["Softplus"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Softplus", +] + + +class Softplus(bijector.Bijector): + """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. + + The softplus `Bijector` has the following two useful properties: + + * The domain is the positive real numbers + * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as + the `Exp` `Bijector`. + + The optional nonzero `hinge_softness` parameter changes the transition at + zero. With `hinge_softness = c`, the bijector is: + + ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` + + For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, + so the behavior for large `x` is the same as the standard softplus. + + As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, + approaching `max(0, x)`. + + * `c = 1` is the default. + * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. + * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. + * `c = 0` results in a non-bijective transformation and triggers an exception. + + Example Use: + + ```python + # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + softplus = Softplus(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + log(1 + exp(x)) == softplus.forward(x) + log(exp(x) - 1) == softplus.inverse(x) + ``` + + Note: log(.) and exp(.) are applied element-wise but the Jacobian is a + reduction over the event space. + """ + + @distribution_util.AppendDocstring( + kwargs_dict={ + "hinge_softness": ( + "Nonzero floating point `Tensor`. Controls the softness of what " + "would otherwise be a kink at the origin. Default is 1.0")}) + def __init__(self, + event_ndims=0, + hinge_softness=None, + validate_args=False, + name="softplus"): + with ops.name_scope(name, values=[hinge_softness]): + if hinge_softness is not None: + self._hinge_softness = ops.convert_to_tensor( + hinge_softness, name="hinge_softness") + else: + self._hinge_softness = None + if validate_args: + nonzero_check = check_ops.assert_none_equal( + ops.convert_to_tensor( + 0, dtype=self.hinge_softness.dtype), + self.hinge_softness, + message="hinge_softness must be non-zero") + self._hinge_softness = control_flow_ops.with_dependencies( + [nonzero_check], self.hinge_softness) + + super(Softplus, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self.hinge_softness is None: + return nn_ops.softplus(x) + hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) + return hinge_softness * nn_ops.softplus(x / hinge_softness) + + def _inverse(self, y): + if self.hinge_softness is None: + return distribution_util.softplus_inverse(y) + hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) + return hinge_softness * distribution_util.softplus_inverse( + y / hinge_softness) + + def _inverse_log_det_jacobian(self, y): + # Could also do: + # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), + # axis=event_dims) + # but the following is more numerically stable. Ie, + # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] + # ==> dX/dY = exp{Y} / (exp{Y} - 1) + # = 1 / (1 - exp{-Y}), + # which is the most stable for large Y > 0. For small Y, we use + # 1 - exp{-Y} approx Y. + if self.hinge_softness is not None: + y /= math_ops.cast(self.hinge_softness, y.dtype) + return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), + axis=self._event_dims_tensor(y)) + + def _forward_log_det_jacobian(self, x): + if self.hinge_softness is not None: + x /= math_ops.cast(self.hinge_softness, x.dtype) + return -math_ops.reduce_sum(nn_ops.softplus(-x), + axis=self._event_dims_tensor(x)) + + @property + def hinge_softness(self): + return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py deleted file mode 100644 index 81957fcf78922fa15fd20a25d144071f431161ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Softplus bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "Softplus", -] - - -class Softplus(bijector.Bijector): - """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. - - The softplus `Bijector` has the following two useful properties: - - * The domain is the positive real numbers - * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as - the `Exp` `Bijector`. - - The optional nonzero `hinge_softness` parameter changes the transition at - zero. With `hinge_softness = c`, the bijector is: - - ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` - - For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, - so the behavior for large `x` is the same as the standard softplus. - - As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, - approaching `max(0, x)`. - - * `c = 1` is the default. - * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. - * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. - * `c = 0` results in a non-bijective transformation and triggers an exception. - - Example Use: - - ```python - # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - softplus = Softplus(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - log(1 + exp(x)) == softplus.forward(x) - log(exp(x) - 1) == softplus.inverse(x) - ``` - - Note: log(.) and exp(.) are applied element-wise but the Jacobian is a - reduction over the event space. - """ - - @distribution_util.AppendDocstring( - kwargs_dict={ - "hinge_softness": ( - "Nonzero floating point `Tensor`. Controls the softness of what " - "would otherwise be a kink at the origin. Default is 1.0")}) - def __init__(self, - event_ndims=0, - hinge_softness=None, - validate_args=False, - name="softplus"): - with ops.name_scope(name, values=[hinge_softness]): - if hinge_softness is not None: - self._hinge_softness = ops.convert_to_tensor( - hinge_softness, name="hinge_softness") - else: - self._hinge_softness = None - if validate_args: - nonzero_check = check_ops.assert_none_equal( - ops.convert_to_tensor( - 0, dtype=self.hinge_softness.dtype), - self.hinge_softness, - message="hinge_softness must be non-zero") - self._hinge_softness = control_flow_ops.with_dependencies( - [nonzero_check], self.hinge_softness) - - super(Softplus, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self.hinge_softness is None: - return nn_ops.softplus(x) - hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) - return hinge_softness * nn_ops.softplus(x / hinge_softness) - - def _inverse(self, y): - if self.hinge_softness is None: - return distribution_util.softplus_inverse(y) - hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) - return hinge_softness * distribution_util.softplus_inverse( - y / hinge_softness) - - def _inverse_log_det_jacobian(self, y): - # Could also do: - # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), - # axis=event_dims) - # but the following is more numerically stable. Ie, - # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] - # ==> dX/dY = exp{Y} / (exp{Y} - 1) - # = 1 / (1 - exp{-Y}), - # which is the most stable for large Y > 0. For small Y, we use - # 1 - exp{-Y} approx Y. - if self.hinge_softness is not None: - y /= math_ops.cast(self.hinge_softness, y.dtype) - return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), - axis=self._event_dims_tensor(y)) - - def _forward_log_det_jacobian(self, x): - if self.hinge_softness is not None: - x /= math_ops.cast(self.hinge_softness, x.dtype) - return -math_ops.reduce_sum(nn_ops.softplus(-x), - axis=self._event_dims_tensor(x)) - - @property - def hinge_softness(self): - return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index d439f28884d8bd7f2b808317e10c5b5e44bfcfa2..00520bcda85e9527767e6342bf75f10667c264a8 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -18,12 +18,132 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.weibull_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Weibull"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Weibull", +] + + +class Weibull(bijector.Bijector): + """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. + + This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): + + ```none + Y ~ Weibull(scale, concentration) + pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( + scale / concentration) ** (concentration - 1) * exp( + -(y / scale) ** concentration) + ``` + """ + + def __init__(self, + scale=1., + concentration=1., + event_ndims=0, + validate_args=False, + name="weibull"): + """Instantiates the `Weibull` bijector. + + Args: + scale: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `concentration`. + This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + concentration: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[scale, concentration]): + self._scale = ops.convert_to_tensor(scale, name="scale") + self._concentration = ops.convert_to_tensor( + concentration, name="concentration") + check_ops.assert_same_float_dtype([self._scale, self._concentration]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, + message="Argument scale was not positive") + ], self._scale) + self._concentration = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._concentration, + message="Argument concentration was not positive") + ], self._concentration) + + super(Weibull, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def scale(self): + """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._scale + + @property + def concentration(self): + """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._concentration + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + return -math_ops.expm1(-((x / self.scale) ** self.concentration)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + -math_ops.log1p(-y) + + (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + + math_ops.log(self.scale / self.concentration), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + -(x / self.scale) ** self.concentration + + (self.concentration - 1) * math_ops.log(x) + + math_ops.log(self.concentration) + + -self.concentration * math_ops.log(self.scale), + axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args: + return x + is_valid = check_ops.assert_non_negative( + x, + message="Forward transformation input must be at least {}.".format(0)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py deleted file mode 100644 index 00520bcda85e9527767e6342bf75f10667c264a8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Weibull bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Weibull", -] - - -class Weibull(bijector.Bijector): - """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. - - This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): - - ```none - Y ~ Weibull(scale, concentration) - pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( - scale / concentration) ** (concentration - 1) * exp( - -(y / scale) ** concentration) - ``` - """ - - def __init__(self, - scale=1., - concentration=1., - event_ndims=0, - validate_args=False, - name="weibull"): - """Instantiates the `Weibull` bijector. - - Args: - scale: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `concentration`. - This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - concentration: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[scale, concentration]): - self._scale = ops.convert_to_tensor(scale, name="scale") - self._concentration = ops.convert_to_tensor( - concentration, name="concentration") - check_ops.assert_same_float_dtype([self._scale, self._concentration]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, - message="Argument scale was not positive") - ], self._scale) - self._concentration = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._concentration, - message="Argument concentration was not positive") - ], self._concentration) - - super(Weibull, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def scale(self): - """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._scale - - @property - def concentration(self): - """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._concentration - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - return -math_ops.expm1(-((x / self.scale) ** self.concentration)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - -math_ops.log1p(-y) + - (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + - math_ops.log(self.scale / self.concentration), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - -(x / self.scale) ** self.concentration + - (self.concentration - 1) * math_ops.log(x) + - math_ops.log(self.concentration) + - -self.concentration * math_ops.log(self.scale), - axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args: - return x - is_valid = check_ops.assert_non_negative( - x, - message="Forward transformation input must be at least {}.".format(0)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index a17bb091f69b651d21f70a25c5aab61b203e62de..6f5d724a2a945ed8f9c159d8314327c6f994d1db 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -30,7 +30,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution - __all__ = [ "Cauchy", ] @@ -44,16 +43,17 @@ class Cauchy(distribution.Distribution): The probability density function (pdf) is, ```none - pdf(x; loc, scale) = 1 / (pi * scale * (1 + ((x - loc) / scale)**2)) + pdf(x; loc, scale) = 1 / (pi scale (1 + z**2)) + z = (x - loc) / scale ``` where `loc` is the location, and `scale` is the scale. The Cauchy distribution is a member of the [location-scale family]( https://en.wikipedia.org/wiki/Location-scale_family), i.e. + `Y ~ Cauchy(loc, scale)` is equivalent to, ```none X ~ Cauchy(loc=0, scale=1) - Y ~ Cauchy(loc=loc, scale=scale) Y = loc + scale * X ``` @@ -62,14 +62,16 @@ class Cauchy(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Cauchy distribution. - dist = Cauchy(loc=0., scale=3.) + dist = tfd.Cauchy(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Cauchy distributions. - dist = Cauchy(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Cauchy(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -77,18 +79,17 @@ class Cauchy(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - - Arguments are broadcast when possible. - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Cauchy distributions. # Both have median 1, but different scales. - dist = tf.contrib.distributions.Cauchy(loc=1., scale=[11, 22.]) + dist = tfd.Cauchy(loc=1., scale=[11, 22.]) + # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. - dist.prob(3.0) + dist.prob(3.) ``` + """ def __init__(self, @@ -97,7 +98,7 @@ class Cauchy(distribution.Distribution): validate_args=False, allow_nan_stats=True, name="Cauchy"): - """Construct Cauchy distributions with loc and and scale `loc` and `scale`. + """Construct Cauchy distributions. The parameters `loc` and `scale` must be shaped in a way that supports broadcasting (e.g. `loc + scale` is a valid operation). @@ -121,8 +122,8 @@ class Cauchy(distribution.Distribution): """ parameters = locals() with ops.name_scope(name, values=[loc, scale]): - with ops.control_dependencies([check_ops.assert_positive(scale)] if - validate_args else []): + with ops.control_dependencies([check_ops.assert_positive(scale)] + if validate_args else []): self._loc = array_ops.identity(loc, name="loc") self._scale = array_ops.identity(scale, name="scale") check_ops.assert_same_float_dtype([self._loc, self._scale]) @@ -138,8 +139,8 @@ class Cauchy(distribution.Distribution): @staticmethod def _param_shapes(sample_shape): return dict( - zip(("loc", "scale"), ([ops.convert_to_tensor( - sample_shape, dtype=dtypes.int32)] * 2))) + zip(("loc", "scale"), + ([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))) @property def loc(self): @@ -153,13 +154,10 @@ class Cauchy(distribution.Distribution): def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( - array_ops.shape(self.loc), - array_ops.shape(self.scale)) + array_ops.shape(self.loc), array_ops.shape(self.scale)) def _batch_shape(self): - return array_ops.broadcast_static_shape( - self.loc.shape, - self.scale.shape) + return array_ops.broadcast_static_shape(self.loc.shape, self.scale.shape) def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 599c855cda434d9249187d5d154d50a8a8c49a6c..1d4c5660d8d73b7b6a7e758fc834ccfddeb5c8ea 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -121,7 +121,7 @@ class ConditionalTransformedDistribution( log_prob = self.distribution.log_prob(x, **distribution_kwargs) if self._is_maybe_event_override: log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) - return ildj + log_prob + return math_ops.cast(ildj, log_prob.dtype) + log_prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _prob(self, y, bijector_kwargs=None, distribution_kwargs=None): @@ -143,7 +143,7 @@ class ConditionalTransformedDistribution( prob = self.distribution.prob(x, **distribution_kwargs) if self._is_maybe_event_override: prob = math_ops.reduce_prod(prob, self._reduce_event_indices) - return math_ops.exp(ildj) * prob + return math_ops.exp(math_ops.cast(ildj, prob.dtype)) * prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None): diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 850d08d1bd69ebc7661557d648e2bffe77e6a908..8049522e9f5dc26b244b7e710a9ae8b981efd6b6 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -290,8 +290,10 @@ class VectorDeterministic(_BaseDeterministic): #### Examples ```python + tfd = tf.contrib.distributions + # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. - constant = tf.contrib.distributions.Deterministic([0., 2.]) + constant = tfd.Deterministic([0., 2.]) constant.prob([0., 2.]) ==> 1. constant.prob([0., 3.]) @@ -299,7 +301,7 @@ class VectorDeterministic(_BaseDeterministic): # Initialize a [3] batch of constants on R^2. loc = [[0., 1.], [2., 3.], [4., 5.]] - constant = constant_lib.VectorDeterministic(loc) + constant = tfd.VectorDeterministic(loc) constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]]) ==> [1., 0., 0.] ``` diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 869b5698e57d199755ce1686a74a1eafe3b73e7d..a4d249d41ec9733721a3583d3708e0da56db1733 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -19,9 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import linalg -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -330,54 +328,14 @@ def shapes_from_loc_and_scale(loc, scale, name="shapes_from_loc_and_scale"): else: loc_batch_shape = ops.convert_to_tensor(loc_batch_shape, name="loc_batch_shape") + # This is defined in the core util module. + # pylint: disable=undefined-variable batch_shape = prefer_static_broadcast_shape(batch_shape, loc_batch_shape) + # pylint: enable=undefined-variable return batch_shape, event_shape -def prefer_static_broadcast_shape( - shape1, shape2, name="prefer_static_broadcast_shape"): - """Convenience function which statically broadcasts shape when possible. - - Args: - shape1: `1-D` integer `Tensor`. Already converted to tensor! - shape2: `1-D` integer `Tensor`. Already converted to tensor! - name: A string name to prepend to created ops. - - Returns: - The broadcast shape, either as `TensorShape` (if broadcast can be done - statically), or as a `Tensor`. - """ - with ops.name_scope(name, values=[shape1, shape2]): - def make_shape_tensor(x): - return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32) - - def get_tensor_shape(s): - if isinstance(s, tensor_shape.TensorShape): - return s - s_ = tensor_util.constant_value(make_shape_tensor(s)) - if s_ is not None: - return tensor_shape.TensorShape(s_) - return None - - def get_shape_tensor(s): - if not isinstance(s, tensor_shape.TensorShape): - return make_shape_tensor(s) - if s.is_fully_defined(): - return make_shape_tensor(s.as_list()) - raise ValueError("Cannot broadcast from partially " - "defined `TensorShape`.") - - shape1_ = get_tensor_shape(shape1) - shape2_ = get_tensor_shape(shape2) - if shape1_ is not None and shape2_ is not None: - return array_ops.broadcast_static_shape(shape1_, shape2_) - - shape1_ = get_shape_tensor(shape1) - shape2_ = get_shape_tensor(shape2) - return array_ops.broadcast_dynamic_shape(shape1_, shape2_) - - def get_broadcast_shape(*tensors): """Get broadcast shape as a Python list of integers (preferred) or `Tensor`. diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index ba8d3c639b397422f0f6210ba9f48650f0da1e3e..d0efaefb8e78ddf4436e9e5a112d2c1cdddaf3b5 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -62,15 +62,17 @@ class _Gumbel(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Gumbel distribution. - dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.) + dist = tfd.Gumbel(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Gumbels. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Gumbel(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -85,7 +87,7 @@ class _Gumbel(distribution.Distribution): ```python # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.]) + dist = tfd.Gumbel(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0751a6e0b78cb3d79bd3478e740bb05cd26428 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Half Normal distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import special_math + + +__all__ = [ + "HalfNormal", +] + + +class HalfNormal(distribution.Distribution): + """The Half Normal distribution with scale `scale`. + + #### Mathematical details + + The half normal is a transformation of a centered normal distribution. + If some random variable `X` has normal distribution, + ```none + X ~ Normal(0.0, scale) + Y = |X| + ``` + Then `Y` will have half normal distribution. The probability density + function (pdf) is: + + ```none + pdf(x; scale, x > 0) = sqrt(2) / (scale * sqrt(pi)) * + exp(- 1/2 * (x / scale) ** 2) + ) + ``` + Where `scale = sigma` is the standard deviation of the underlying normal + distribution. + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + # Define a single scalar HalfNormal distribution. + dist = tf.contrib.distributions.HalfNormal(scale=3.0) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued HalfNormals. + # The first has scale 11.0, the second 22.0 + dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0]) + + # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5, + # returning a length two tensor. + dist.prob([1.0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + ``` + + """ + + def __init__(self, + scale, + validate_args=False, + allow_nan_stats=True, + name="HalfNormal"): + """Construct HalfNormals with scale `scale`. + + Args: + scale: Floating point tensor; the scales of the distribution(s). + Must contain only positive values. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + parameters = locals() + with ops.name_scope(name, values=[scale]): + with ops.control_dependencies([check_ops.assert_positive(scale)] if + validate_args else []): + self._scale = array_ops.identity(scale, name="scale") + super(HalfNormal, self).__init__( + dtype=self._scale.dtype, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._scale], + name=name) + + @staticmethod + def _param_shapes(sample_shape): + return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} + + @property + def scale(self): + """Distribution parameter for the scale.""" + return self._scale + + def _batch_shape_tensor(self): + return array_ops.shape(self.scale) + + def _batch_shape(self): + return self.scale.shape + + def _event_shape_tensor(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + sampled = random_ops.random_normal( + shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed) + return math_ops.abs(sampled * self.scale) + + def _prob(self, x): + coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi) + pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2) + return pdf * math_ops.cast(x >= 0, self.dtype) + + def _cdf(self, x): + truncated_x = nn.relu(x) + return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0)) + + def _entropy(self): + return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5 + + def _mean(self): + return self.scale * np.sqrt(2.0) / np.sqrt(np.pi) + + def _quantile(self, p): + return np.sqrt(2.0) * self.scale * special_math.erfinv(p) + + def _mode(self): + return array_ops.zeros(self.batch_shape_tensor()) + + def _variance(self): + return self.scale ** 2.0 * (1.0 - 2.0 / np.pi) diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 6a74ca9a0ae1ad30081d21cc15a65be052a99e2a..cbce005013281ff3c58c94d525d5ce7a865d725a 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -68,11 +68,11 @@ class Independent(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Make independent distribution from a 2-batch Normal. - ind = ds.Independent( - distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), + ind = tfd.Independent( + distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]), reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. @@ -80,8 +80,8 @@ class Independent(distribution_lib.Distribution): ind.event_shape # ==> [2] # Make independent distribution from a 2-batch bivariate Normal. - ind = ds.Independent( - distribution=ds.MultivariateNormalDiag( + ind = tfd.Independent( + distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]), reinterpreted_batch_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 956dee38a378813434656a28a69c89b6ec1e8b72..ee4d86867d48b20e97757bcec57d452085814b80 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -88,8 +88,9 @@ class InverseGamma(distribution.Distribution): #### Examples ```python - dist = InverseGamma(concentration=3.0, rate=2.0) - dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + tfd = tf.contrib.distributions + dist = tfd.InverseGamma(concentration=3.0, rate=2.0) + dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) ``` """ diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 48794a48828fe796e233e968d8c755136ce166ad..473677f8d91b184e029f345bb05f5c5d63df7a40 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -60,15 +60,17 @@ class Logistic(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Logistic distribution. - dist = tf.contrib.distributions.Logistic(loc=0., scale=3.) + dist = tfd.Logistic(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Logistics. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Logistic(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,14 +78,11 @@ class Logistic(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - Arguments are broadcast when possible. - - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.]) + dist = tfd.Logistic(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index e676931d9145e72907d990148ee2d180e0da0258..f2d492f5489a197157558ae727416b51db04793e 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -49,13 +49,13 @@ class Mixture(distribution.Distribution): ```python # Create a mixture of two Gaussians: - ds = tf.contrib.distributions + tfd = tf.contrib.distributions mix = 0.3 - bimix_gauss = ds.Mixture( - cat=ds.Categorical(probs=[mix, 1.-mix]), + bimix_gauss = tfd.Mixture( + cat=tfd.Categorical(probs=[mix, 1.-mix]), components=[ - ds.Normal(loc=-1., scale=0.1), - ds.Normal(loc=+1., scale=0.5), + tfd.Normal(loc=-1., scale=0.1), + tfd.Normal(loc=+1., scale=0.5), ]) # Plot the PDF. diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 5558ef0f255db684b229d129666634e50c625887..0ca236c3761f9d3a0fcc79ff9db792319108db0d 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -43,15 +43,14 @@ class MixtureSameFamily(distribution.Distribution): #### Examples ```python - import matplotlib.pyplot as plt - ds = tf.contrib.distributions + tfd = tf.contrib.distributions ### Create a mixture of two scalar Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.Normal( + components_distribution=tfd.Normal( loc=[-1., 1], # One for each component. scale=[0.1, 0.5])) # And same here. @@ -63,14 +62,15 @@ class MixtureSameFamily(distribution.Distribution): # Plot PDF. x = np.linspace(-2., 3., int(1e4), dtype=np.float32) + import matplotlib.pyplot as plt plt.plot(x, gm.prob(x).eval()); ### Create a mixture of two Bivariate Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.MultivariateNormalDiag( + components_distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], # component 1 [1, -1]], # component 2 scale_identity_multiplier=[.3, .6])) @@ -320,13 +320,14 @@ class MixtureSameFamily(distribution.Distribution): return array_ops.shape(d.batch_shape_tensor())[0] dist_batch_ndims = _get_ndims(self) cat_batch_ndims = _get_ndims(self.mixture_distribution) - bnd = distribution_util.pick_vector( + pad_ndims = array_ops.where( self.mixture_distribution.is_scalar_batch(), - [dist_batch_ndims], [cat_batch_ndims])[0] + dist_batch_ndims, + dist_batch_ndims - cat_batch_ndims) s = array_ops.shape(x) x = array_ops.reshape(x, shape=array_ops.concat([ s[:-1], - array_ops.ones([bnd], dtype=dtypes.int32), + array_ops.ones([pad_ndims], dtype=dtypes.int32), s[-1:], array_ops.ones([self._event_ndims], dtype=dtypes.int32), ], axis=0)) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index 163cf75d990d5fe7ec1e3aaf0040fc71f61774a7..e862552880f4073c8fa8e90134d0633e7484b0bf 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -84,10 +84,10 @@ class MultivariateNormalDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -101,7 +101,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -119,7 +119,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate Gaussians. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 040bc230722194316b8a74627344e315a2578281..413e88f03ae0286c294f3404549a73e1a47dcff7 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -86,7 +86,7 @@ class MultivariateNormalDiagPlusLowRank( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is @@ -97,7 +97,7 @@ class MultivariateNormalDiagPlusLowRank( [-1, 1], [2, -0.5]] # shape: [3, 2] m = [4., 5] # shape: [2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu scale_diag=d scale_perturb_factor=U, @@ -118,7 +118,7 @@ class MultivariateNormalDiagPlusLowRank( m = [[0.1, 0.2], [0.4, 0.5]] # shape: [b, r] = [2, 2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu, scale_perturb_factor=U, scale_perturb_diag=m) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index f9952b2069d6dfd2593e6bd71ede0badf44cdf98..00a18569fce0175ee39e433dfad796e5f21fe8a4 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -18,12 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import mvn_tril from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops __all__ = [ @@ -73,14 +76,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] cov = [[ 0.36, 0.12, 0.06], [ 0.12, 0.29, -0.13], [ 0.06, -0.13, 0.26]] - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance_matrix=cov) @@ -100,7 +103,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] covariance_matrix = ... # shape: [2, 3, 3], symmetric, positive definite. - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance=covariance_matrix) @@ -167,9 +170,12 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = ops.convert_to_tensor( covariance_matrix, name="covariance_matrix") if validate_args: - assert_symmetric = check_ops.assert_equal( - covariance_matrix, - array_ops.matrix_transpose(covariance_matrix), + tol = np.finfo(covariance_matrix.dtype.as_numpy_dtype).eps * 10 + diff = math_ops.abs( + covariance_matrix + - array_ops.matrix_transpose(covariance_matrix)) + assert_symmetric = check_ops.assert_less( + diff, tol + tol * math_ops.abs(covariance_matrix), message="Matrix was not symmetric.") covariance_matrix = control_flow_ops.with_dependencies( [assert_symmetric], covariance_matrix) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 300bdd5f6064a1cc9c336689ac4fae04338edb30..a7399792892f4c179c05168184d76ec95c168b51 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -90,8 +90,7 @@ class MultivariateNormalLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -103,9 +102,9 @@ class MultivariateNormalLinearOperator( # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale)) + scale=tf.linalg.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -122,9 +121,9 @@ class MultivariateNormalLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 260dcc18f513d5440d3d39368539274c03faa72a..6c7dc4ca7aaf5b3a20b072e9360d15528ad10556 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -76,12 +76,13 @@ class MultivariateNormalTriL( ``` Trainable (batch) lower-triangular matrices can be created with - `ds.matrix_diag_transform()` and/or `ds.fill_triangular()` + `tf.contrib.distributions.matrix_diag_transform()` and/or + `tf.contrib.distributions.fill_triangular()` #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -92,7 +93,7 @@ class MultivariateNormalTriL( # ==> [[ 0.6, 0. , 0. ], # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=scale) @@ -112,7 +113,7 @@ class MultivariateNormalTriL( mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal. - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=tril) @@ -124,9 +125,9 @@ class MultivariateNormalTriL( # Instantiate a "learnable" MVN. dims = 4 with tf.variable_scope("model"): - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), - scale_tril=ds.fill_triangular( + scale_tril=tfd.fill_triangular( tf.get_variable(shape=[dims * (dims + 1) / 2], dtype=tf.float32, name="chol_Sigma"))) ``` diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 8a95038a3c8eccf8a75fea79d0a62f9883b4f13a..92f2bba1828696248c9d9460566a08ba372c3358 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -22,21 +22,135 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib +from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib __all__ = [ "PoissonLogNormalQuadratureCompound", + "quadrature_scheme_lognormal_gauss_hermite", + "quadrature_scheme_lognormal_quantiles", ] +def quadrature_scheme_lognormal_gauss_hermite( + loc, scale, quadrature_size, + validate_args=False, name=None): # pylint: disable=unused-argument + """Use Gauss-Hermite quadrature to form quadrature on positive-reals. + + Note: for a given `quadrature_size`, this method is generally less accurate + than `quadrature_scheme_lognormal_quantiles`. + + Args: + loc: `float`-like (batch of) scalar `Tensor`; the location parameter of + the LogNormal prior. + scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of + the LogNormal prior. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: (Batch of) length-`quadrature_size` vectors representing the + `log_rate` parameters of a `Poisson`. + probs: (Batch of) length-`quadrature_size` vectors representing the + weight associate with each `grid` value. + """ + with ops.name_scope(name, "vector_diffeomixture_quadrature_gauss_hermite", + [loc, scale]): + grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size) + grid = grid.astype(loc.dtype.as_numpy_dtype) + probs = probs.astype(loc.dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1, keepdims=True) + probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype) + # The following maps the broadcast of `loc` and `scale` to each grid + # point, i.e., we are creating several log-rates that correspond to the + # different Gauss-Hermite quadrature points and (possible) batches of + # `loc` and `scale`. + grid = (loc[..., array_ops.newaxis] + + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid) + return grid, probs + + +def quadrature_scheme_lognormal_quantiles( + loc, scale, quadrature_size, + validate_args=False, name=None): + """Use LogNormal quantiles to form quadrature on positive-reals. + + Args: + loc: `float`-like (batch of) scalar `Tensor`; the location parameter of + the LogNormal prior. + scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of + the LogNormal prior. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: (Batch of) length-`quadrature_size` vectors representing the + `log_rate` parameters of a `Poisson`. + probs: (Batch of) length-`quadrature_size` vectors representing the + weight associate with each `grid` value. + """ + with ops.name_scope(name, "quadrature_scheme_lognormal_quantiles", + [loc, scale]): + # Create a LogNormal distribution. + dist = transformed_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=loc, scale=scale), + bijector=Exp(event_ndims=0), + validate_args=validate_args) + batch_ndims = dist.batch_shape.ndims + if batch_ndims is None: + batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] + + def _compute_quantiles(): + """Helper to build quantiles.""" + # Omit {0, 1} since they might lead to Inf/NaN. + zero = array_ops.zeros([], dtype=dist.dtype) + edges = math_ops.linspace(zero, 1., quadrature_size + 3)[1:-1] + # Expand edges so its broadcast across batch dims. + edges = array_ops.reshape(edges, shape=array_ops.concat([ + [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) + quantiles = dist.quantile(edges) + # Cyclically permute left by one. + perm = array_ops.concat([ + math_ops.range(1, 1 + batch_ndims), [0]], axis=0) + quantiles = array_ops.transpose(quantiles, perm) + return quantiles + quantiles = _compute_quantiles() + + # Compute grid as quantile midpoints. + grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2. + # Set shape hints. + grid.set_shape(dist.batch_shape.concatenate([quadrature_size])) + + # By construction probs is constant, i.e., `1 / quadrature_size`. This is + # important, because non-constant probs leads to non-reparameterizable + # samples. + probs = array_ops.fill( + dims=[quadrature_size], + value=1. / math_ops.cast(quadrature_size, dist.dtype)) + + return grid, probs + + class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): """`PoissonLogNormalQuadratureCompound` distribution. @@ -47,30 +161,18 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ```none p(k|loc, scale) = int_{R_+} dl LogNormal(l | loc, scale) Poisson(k | l) - = int_{R} dz ((lambda(z) sqrt(2) scale) - * exp(-z**2) / (lambda(z) sqrt(2 pi) sigma) - * Poisson(k | lambda(z))) - = int_{R} dz exp(-z**2) / sqrt(pi) Poisson(k | lambda(z)) approx= sum{ prob[d] Poisson(k | lambda(grid[d])) : d=0, ..., deg-1 } ``` - where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms - are from [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that - the second line made the substitution: - `z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above] - and `dl = sqrt(2) scale lambda(z) dz` + By default, the `grid` is chosen as quantiles of the `LogNormal` distribution + parameterized by `loc`, `scale` and the `prob` vector is + `[1. / quadrature_size]*quadrature_size`. In the non-approximation case, a draw from the LogNormal prior represents the Poisson rate parameter. Unfortunately, the non-approximate distribution lacks an analytical probability density function (pdf). Therefore the `PoissonLogNormalQuadratureCompound` class implements an approximation based - on [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). + on [quadrature](https://en.wikipedia.org/wiki/Numerical_integration). Note: although the `PoissonLogNormalQuadratureCompound` is approximately the Poisson-LogNormal compound distribution, it is itself a valid distribution. @@ -84,10 +186,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): https://en.wikipedia.org/wiki/Compound_probability_distribution). Using variable-substitution and [numerical quadrature]( https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can - redefine the distribution to be a parameter-less convex combination of `deg` - different Poisson samples. + based on `LogNormal` quantiles) we can redefine the distribution to be a + parameter-less convex combination of `deg` different Poisson samples. That is, defined over positive integers, this distribution is parameterized by a (batch of) `loc` and `scale` scalars. @@ -96,46 +196,51 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ```none pdf(k | loc, scale, deg) - = sum{ prob[d] Poisson(k | lambda=exp(sqrt(2) scale grid[d] + loc)) + = sum{ prob[d] Poisson(k | lambda=exp(grid[d])) : d=0, ..., deg-1 } ``` - where, [e.g., `grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( - https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html) - and `prob = w / sqrt(pi)`. - #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions + # Create two batches of PoissonLogNormalQuadratureCompounds, one with # prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.` - pln = ds.PoissonLogNormalQuadratureCompound( + pln = tfd.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + quadrature_size=10, validate_args=True) """ def __init__(self, loc, scale, - quadrature_grid_and_probs=None, + quadrature_size=8, + quadrature_fn=quadrature_scheme_lognormal_quantiles, validate_args=False, allow_nan_stats=True, name="PoissonLogNormalQuadratureCompound"): - """Constructs the PoissonLogNormalQuadratureCompound on `R**k`. + """Constructs the PoissonLogNormalQuadratureCompound`. + + Note: `probs` returned by (optional) `quadrature_fn` are presumed to be + either a length-`quadrature_size` vector or a batch of vectors in 1-to-1 + correspondence with the returned `grid`. (I.e., broadcasting is only + partially supported.) Args: loc: `float`-like (batch of) scalar `Tensor`; the location parameter of the LogNormal prior. scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the LogNormal prior. - quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s - representing the sample points and the corresponding (possibly - normalized) weight. When `None`, defaults to: - `np.polynomial.hermite.hermgauss(deg=8)`. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + quadrature_fn: Python callable taking `loc`, `scale`, + `quadrature_size`, `validate_args` and returning `tuple(grid, probs)` + representing the LogNormal grid and corresponding normalized weight. + normalized) weight. + Default value: `quadrature_scheme_lognormal_quantiles`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -147,47 +252,41 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): name: Python `str` name prefixed to Ops created by this class. Raises: - TypeError: if `loc.dtype != scale[0].dtype`. + TypeError: if `quadrature_grid` and `quadrature_probs` have different base + `dtype`. """ parameters = locals() with ops.name_scope(name, values=[loc, scale]): - loc = ops.convert_to_tensor(loc, name="loc") - self._loc = loc + if loc is not None: + loc = ops.convert_to_tensor(loc, name="loc") + if scale is not None: + scale = ops.convert_to_tensor( + scale, dtype=None if loc is None else loc.dtype, name="scale") + self._quadrature_grid, self._quadrature_probs = tuple(quadrature_fn( + loc, scale, quadrature_size, validate_args)) + + dt = self._quadrature_grid.dtype + if dt.base_dtype != self._quadrature_probs.dtype.base_dtype: + raise TypeError("Quadrature grid dtype ({}) does not match quadrature " + "probs dtype ({}).".format( + dt.name, self._quadrature_probs.dtype.name)) - scale = ops.convert_to_tensor(scale, name="scale") - self._scale = scale - - dtype = loc.dtype.base_dtype - if dtype != scale.dtype.base_dtype: - raise TypeError( - "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format( - loc.dtype.name, scale.dtype.name)) - - grid, probs = distribution_util.process_quadrature_grid_and_probs( - quadrature_grid_and_probs, dtype, validate_args) - self._quadrature_grid = grid - self._quadrature_probs = probs - self._quadrature_size = distribution_util.dimension_size(probs, axis=0) + self._distribution = poisson_lib.Poisson( + log_rate=self._quadrature_grid, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats) self._mixture_distribution = categorical_lib.Categorical( logits=math_ops.log(self._quadrature_probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) - # The following maps the broadcast of `loc` and `scale` to each grid - # point, i.e., we are creating several log-rates that correspond to the - # different Gauss-Hermite quadrature points and (possible) batches of - # `loc` and `scale`. - self._log_rate = (loc[..., array_ops.newaxis] - + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid) - - self._distribution = poisson_lib.Poisson( - log_rate=self._log_rate, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats) + self._loc = loc + self._scale = scale + self._quadrature_size = quadrature_size super(PoissonLogNormalQuadratureCompound, self).__init__( - dtype=dtype, + dtype=dt, reparameterization_type=distribution_lib.NOT_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, @@ -197,12 +296,12 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): @property def mixture_distribution(self): - """Distribution which randomly selects a Poisson with Gauss-Hermite rate.""" + """Distribution which randomly selects a Poisson with quadrature param.""" return self._mixture_distribution @property def distribution(self): - """Base Poisson parameterized by a Gauss-Hermite grid of rates.""" + """Base Poisson parameterized by a quadrature grid.""" return self._distribution @property @@ -216,24 +315,18 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): return self._scale @property - def quadrature_grid(self): - """Quadrature grid points.""" - return self._quadrature_grid - - @property - def quadrature_probs(self): - """Quadrature normalized weights.""" - return self._quadrature_probs + def quadrature_size(self): + return self._quadrature_size def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( - array_ops.shape(self.loc), - array_ops.shape(self.scale)) + self.distribution.batch_shape_tensor(), + array_ops.shape(self.mixture_distribution.logits))[:-1] def _batch_shape(self): return array_ops.broadcast_static_shape( - self.loc.shape, - self.scale.shape) + self.distribution.batch_shape, + self.mixture_distribution.logits.shape)[:-1] def _event_shape(self): return tensor_shape.scalar() @@ -241,18 +334,31 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): def _sample_n(self, n, seed=None): # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # ids as a [n]-shaped vector. - batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32) - if self.batch_shape.is_fully_defined() - else math_ops.reduce_prod(self.batch_shape_tensor())) + batch_size = self.batch_shape.num_elements() + if batch_size is None: + batch_size = math_ops.reduce_prod(self.batch_shape_tensor()) + # We need to "sample extra" from the mixture distribution if it doesn't + # already specify a probs vector for each batch coordinate. + # We only support this kind of reduced broadcasting, i.e., there is exactly + # one probs vector for all batch dims or one for each. ids = self._mixture_distribution.sample( sample_shape=concat_vectors( [n], distribution_util.pick_vector( - self.is_scalar_batch(), - np.int32([]), - [batch_size])), + self.mixture_distribution.is_scalar_batch(), + [batch_size], + np.int32([]))), seed=distribution_util.gen_new_seed( seed, "poisson_lognormal_quadrature_compound")) + # We need to flatten batch dims in case mixture_distribution has its own + # batch dims. + ids = array_ops.reshape(ids, shape=concat_vectors( + [n], + distribution_util.pick_vector( + self.is_scalar_batch(), + np.int32([]), + np.int32([-1])))) + # Stride `quadrature_size` for `batch_size` number of times. offset = math_ops.range(start=0, limit=batch_size * self._quadrature_size, @@ -275,7 +381,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): def _mean(self): return math_ops.exp( math_ops.reduce_logsumexp( - self.mixture_distribution.logits + self._log_rate, + self.mixture_distribution.logits + self.distribution.log_rate, axis=-1)) def _variance(self): @@ -292,7 +398,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): # where, # # Z|v ~ interpolate_affine[v](distribution) - # V ~ mixture_distrubution + # V ~ mixture_distribution # # thus, # @@ -300,7 +406,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 } v = array_ops.stack([ # log(self.distribution.variance()) = log(Var[d]) = log(rate[d]) - self._log_rate, + self.distribution.log_rate, # log((Mean[d] - Mean)**2) 2. * math_ops.log( math_ops.abs(self.distribution.mean() @@ -311,14 +417,9 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): axis=[-2, -1]) -def static_value(x): - """Returns the static value of a `Tensor` or `None`.""" - return tensor_util.constant_value(ops.convert_to_tensor(x)) - - def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" - args_ = [static_value(x) for x in args] + args_ = [distribution_util.static_value(x) for x in args] if any(vec is None for vec in args_): return array_ops.concat(args, axis=0) return [val for vec in args_ for val in vec] diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index 2a4b92c72900f79785e7e34b77179d3decbace5b..dfc813361977c159d8d48f9d5b9ff03db5b4acdc 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -28,12 +28,190 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import spectral_ops +from tensorflow.python.ops.distributions import util __all__ = [ + "auto_correlation", "percentile", ] +# TODO(langmore) Write separate versions of this for real/complex dtype, taking +# advantage of optimized real-fft ops. +def auto_correlation( + x, + axis=-1, + max_lags=None, + center=True, + normalize=True, + name="auto_correlation"): + """Auto correlation along one axis. + + Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation + `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate) + + ``` + RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) }, + W[n] := (X[n] - MU) / S, + MU := E{ X[0] }, + S**2 := E{ (X[0] - MU) Conj(X[0] - MU) }. + ``` + + This function takes the viewpoint that `x` is (along one axis) a finite + sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an + estimate of `RXX[m]` as follows: + + After extending `x` from length `L` to `inf` by zero padding, the auto + correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as + + ``` + rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]), + w[n] := (x[n] - mu) / s, + mu := L**-1 sum_n x[n], + s**2 := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu) + ``` + + The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users + often set `max_lags` small enough so that the entire output is meaningful. + + Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by + `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation + contains a slight bias, which goes to zero as `len(x) - m --> infinity`. + + Args: + x: `float32` or `complex64` `Tensor`. + axis: Python `int`. The axis number along which to compute correlation. + Other dimensions index different batch members. + max_lags: Positive `int` tensor. The maximum value of `m` to consider + (in equation above). If `max_lags >= x.shape[axis]`, we effectively + re-set `max_lags` to `x.shape[axis] - 1`. + center: Python `bool`. If `False`, do not subtract the mean estimate `mu` + from `x[n]` when forming `w[n]`. + normalize: Python `bool`. If `False`, do not divide by the variance + estimate `s**2` when forming `w[n]`. + name: `String` name to prepend to created ops. + + Returns: + `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for + `i != axis`, and `rxx.shape[axis] = max_lags + 1`. + + Raises: + TypeError: If `x` is not a supported type. + """ + # Implementation details: + # Extend length N / 2 1-D array x to length N by zero padding onto the end. + # Then, set + # F[x]_k := sum_n x_n exp{-i 2 pi k n / N }. + # It is not hard to see that + # F[x]_k Conj(F[x]_k) = F[R]_k, where + # R_m := sum_n x_n Conj(x_{(n - m) mod N}). + # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m]. + + # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT + # based version of estimating RXX. + # Note that this is a special case of the Wiener-Khinchin Theorem. + with ops.name_scope(name, values=[x]): + x = ops.convert_to_tensor(x, name="x") + + # Rotate dimensions of x in order to put axis at the rightmost dim. + # FFT op requires this. + rank = util.prefer_static_rank(x) + if axis < 0: + axis = rank + axis + shift = rank - 1 - axis + # Suppose x.shape[axis] = T, so there are T "time" steps. + # ==> x_rotated.shape = B + [T], + # where B is x_rotated's batch shape. + x_rotated = util.rotate_transpose(x, shift) + + if center: + x_rotated -= math_ops.reduce_mean(x_rotated, axis=-1, keepdims=True) + + # x_len = N / 2 from above explanation. The length of x along axis. + # Get a value for x_len that works in all cases. + x_len = util.prefer_static_shape(x_rotated)[-1] + + # TODO(langmore) Investigate whether this zero padding helps or hurts. At + # the moment is is necessary so that all FFT implementations work. + # Zero pad to the next power of 2 greater than 2 * x_len, which equals + # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). + x_len_float64 = math_ops.cast(x_len, np.float64) + target_length = math_ops.pow( + np.float64(2.), + math_ops.ceil(math_ops.log(x_len_float64 * 2) / np.log(2.))) + pad_length = math_ops.cast(target_length - x_len_float64, np.int32) + + # We should have: + # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length] + # = B + [T + pad_length] + x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length) + + dtype = x.dtype + if not dtype.is_complex: + if not dtype.is_floating: + raise TypeError("Argument x must have either float or complex dtype" + " found: {}".format(dtype)) + x_rotated_pad = math_ops.complex(x_rotated_pad, + dtype.real_dtype.as_numpy_dtype(0.)) + + # Autocorrelation is IFFT of power-spectral density (up to some scaling). + fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad) + spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad) + # shifted_product is R[m] from above detailed explanation. + # It is the inner product sum_n X[n] * Conj(X[n - m]). + shifted_product = spectral_ops.ifft(spectral_density) + + # Cast back to real-valued if x was real to begin with. + shifted_product = math_ops.cast(shifted_product, dtype) + + # Figure out if we can deduce the final static shape, and set max_lags. + # Use x_rotated as a reference, because it has the time dimension in the far + # right, and was created before we performed all sorts of crazy shape + # manipulations. + know_static_shape = True + if not x_rotated.shape.is_fully_defined(): + know_static_shape = False + if max_lags is None: + max_lags = x_len - 1 + else: + max_lags = ops.convert_to_tensor(max_lags, name="max_lags") + max_lags_ = tensor_util.constant_value(max_lags) + if max_lags_ is None or not know_static_shape: + know_static_shape = False + max_lags = math_ops.minimum(x_len - 1, max_lags) + else: + max_lags = min(x_len - 1, max_lags_) + + # Chop off the padding. + # We allow users to provide a huge max_lags, but cut it off here. + # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags] + shifted_product_chopped = shifted_product[..., :max_lags + 1] + + # If possible, set shape. + if know_static_shape: + chopped_shape = x_rotated.shape.as_list() + chopped_shape[-1] = min(x_len, max_lags + 1) + shifted_product_chopped.set_shape(chopped_shape) + + # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The + # other terms were zeros arising only due to zero padding. + # `denominator = (N / 2 - m)` (defined below) is the proper term to + # divide by by to make this an unbiased estimate of the expectation + # E[X[n] Conj(X[n - m])]. + x_len = math_ops.cast(x_len, dtype.real_dtype) + max_lags = math_ops.cast(max_lags, dtype.real_dtype) + denominator = x_len - math_ops.range(0., max_lags + 1.) + denominator = math_ops.cast(denominator, dtype) + shifted_product_rotated = shifted_product_chopped / denominator + + if normalize: + shifted_product_rotated /= shifted_product_rotated[..., :1] + + # Transpose dimensions back to those of x. + return util.rotate_transpose(shifted_product_rotated, -shift) + + # TODO(langmore) To make equivalent to numpy.percentile: # Make work with a sequence of floats or single float for 'q'. # Make work with "linear", "midpoint" interpolation. (linear should be default) diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index b05f15771a3a94779ffddea8f16ad2fa4ea2fdd1..c4b8f055b7fbc3f0835b503eddd7617610326d8c 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -115,7 +115,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. - Default is `ds.Normal(0., 1.)`. + Default is `tf.distributions.Normal(0., 1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py index 77f2a39273dc365a4ac202d846dd2bc364655c86..bfc727450f5e48ecbf98bf8ab0475ec67c9e7137 100644 --- a/tensorflow/contrib/distributions/python/ops/test_util.py +++ b/tensorflow/contrib/distributions/python/ops/test_util.py @@ -40,6 +40,7 @@ class DiscreteScalarDistributionTestHelpers(object): def run_test_sample_consistent_log_prob( self, sess_run_fn, dist, num_samples=int(1e5), num_threshold=int(1e3), seed=42, + batch_size=None, rtol=1e-2, atol=0.): """Tests that sample/log_prob are consistent with each other. @@ -66,6 +67,8 @@ class DiscreteScalarDistributionTestHelpers(object): seed: Python `int` indicating the seed to use when sampling from `dist`. In general it is not recommended to use `None` during a test as this increases the likelihood of spurious test failure. + batch_size: Hint for unpacking result of samples. Default: `None` means + batch_size is inferred. rtol: Python `float`-type indicating the admissible relative error between analytical and sample statistics. atol: Python `float`-type indicating the admissible absolute error between @@ -80,10 +83,11 @@ class DiscreteScalarDistributionTestHelpers(object): # Histogram only supports vectors so we call it once per batch coordinate. y = dist.sample(num_samples, seed=seed) y = array_ops.reshape(y, shape=[num_samples, -1]) - batch_size = math_ops.reduce_prod(dist.batch_shape_tensor()) + if batch_size is None: + batch_size = math_ops.reduce_prod(dist.batch_shape_tensor()) batch_dims = array_ops.shape(dist.batch_shape_tensor())[0] edges_expanded_shape = 1 + array_ops.pad([-2], paddings=[[0, batch_dims]]) - for b, x in enumerate(array_ops.unstack(y, axis=1)): + for b, x in enumerate(array_ops.unstack(y, num=batch_size, axis=1)): counts, edges = self.histogram(x) edges = array_ops.reshape(edges, edges_expanded_shape) probs = math_ops.exp(dist.log_prob(edges)) diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 92043d6a08833888c36009261addca0d14949ea8..7ce8a83fd91e2dfaa0ccef633f803b3ae595e646 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -22,30 +22,176 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator +from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib -static_value = distribution_util.static_value - __all__ = [ "VectorDiffeomixture", + "quadrature_scheme_softmaxnormal_gauss_hermite", + "quadrature_scheme_softmaxnormal_quantiles", ] +def quadrature_scheme_softmaxnormal_gauss_hermite( + loc, scale, quadrature_size, + validate_args=False, name=None): + """Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex. + + Note: for a given `quadrature_size`, this method is generally less accurate + than `quadrature_scheme_softmaxnormal_quantiles`. + + Args: + loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + Represents the `location` parameter of the SoftmaxNormal used for + selecting one of the `K` affine transformations. + scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + Represents the `scale` parameter of the SoftmaxNormal used for + selecting one of the `K` affine transformations. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + convex combination of affine parameters for `K` components. + `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex. + probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + associated with each grid point. + """ + with ops.name_scope(name, "quadrature_scheme_softmaxnormal_gauss_hermite", + [loc, scale]): + loc = ops.convert_to_tensor(loc, name="loc") + dt = loc.dtype.base_dtype + scale = ops.convert_to_tensor(scale, dtype=dt, name="scale") + + loc = maybe_check_quadrature_param(loc, "loc", validate_args) + scale = maybe_check_quadrature_param(scale, "scale", validate_args) + + grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size) + grid = grid.astype(loc.dtype.as_numpy_dtype) + probs = probs.astype(loc.dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1, keepdims=True) + probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype) + + grid = softmax( + -distribution_util.pad( + (loc[..., array_ops.newaxis] + + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid), + axis=-2, + front=True), + axis=-2) # shape: [B, components, deg] + + return grid, probs + + +def quadrature_scheme_softmaxnormal_quantiles( + loc, scale, quadrature_size, + validate_args=False, name=None): + """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex. + + Args: + loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + Represents the `location` parameter of the SoftmaxNormal used for + selecting one of the `K` affine transformations. + scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + Represents the `scale` parameter of the SoftmaxNormal used for + selecting one of the `K` affine transformations. + quadrature_size: Python scalar `int` representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + convex combination of affine parameters for `K` components. + `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex. + probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + associated with each grid point. + """ + with ops.name_scope(name, "softmax_normal_grid_and_probs", [loc, scale]): + loc = ops.convert_to_tensor(loc, name="loc") + dt = loc.dtype.base_dtype + scale = ops.convert_to_tensor(scale, dtype=dt, name="scale") + + loc = maybe_check_quadrature_param(loc, "loc", validate_args) + scale = maybe_check_quadrature_param(scale, "scale", validate_args) + + dist = normal_lib.Normal(loc=loc, scale=scale) + + def _get_batch_ndims(): + """Helper to get dist.batch_shape.ndims, statically if possible.""" + ndims = dist.batch_shape.ndims + if ndims is None: + ndims = array_ops.shape(dist.batch_shape_tensor())[0] + return ndims + batch_ndims = _get_batch_ndims() + + def _get_final_shape(qs): + """Helper to build `TensorShape`.""" + bs = dist.batch_shape.with_rank_at_least(1) + num_components = bs[-1].value + if num_components is not None: + num_components += 1 + tail = tensor_shape.TensorShape([num_components, qs]) + return bs[:-1].concatenate(tail) + + def _compute_quantiles(): + """Helper to build quantiles.""" + # Omit {0, 1} since they might lead to Inf/NaN. + zero = array_ops.zeros([], dtype=dist.dtype) + edges = math_ops.linspace(zero, 1., quadrature_size + 3)[1:-1] + # Expand edges so its broadcast across batch dims. + edges = array_ops.reshape(edges, shape=array_ops.concat([ + [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) + quantiles = dist.quantile(edges) + quantiles = SoftmaxCentered(event_ndims=1).forward(quantiles) + # Cyclically permute left by one. + perm = array_ops.concat([ + math_ops.range(1, 1 + batch_ndims), [0]], axis=0) + quantiles = array_ops.transpose(quantiles, perm) + quantiles.set_shape(_get_final_shape(quadrature_size + 1)) + return quantiles + quantiles = _compute_quantiles() + + # Compute grid as quantile midpoints. + grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2. + # Set shape hints. + grid.set_shape(_get_final_shape(quadrature_size)) + + # By construction probs is constant, i.e., `1 / quadrature_size`. This is + # important, because non-constant probs leads to non-reparameterizable + # samples. + probs = array_ops.fill( + dims=[quadrature_size], + value=1. / math_ops.cast(quadrature_size, dist.dtype)) + + return grid, probs + + class VectorDiffeomixture(distribution_lib.Distribution): """VectorDiffeomixture distribution. @@ -188,8 +334,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and # another with mix_loc=[1]. In both cases, `K=2` and the affine @@ -197,20 +342,20 @@ class VectorDiffeomixture(distribution_lib.Distribution): # k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity # k=1: loc=[2.]*dims scale=LinOpDiag dims = 5 - vdm = ds.VectorDiffeomixture( + vdm = tfd.VectorDiffeomixture( mix_loc=[[0.], [1]], mix_scale=[1.], - distribution=ds.Normal(loc=0., scale=1.), + distribution=tfd.Normal(loc=0., scale=1.), loc=[ None, # Equivalent to `np.zeros(dims, dtype=np.float32)`. np.float32([2.]*dims), ], scale=[ - la.LinearOperatorScaledIdentity( + tf.linalg.LinearOperatorScaledIdentity( num_rows=dims, multiplier=np.float32(1.1), is_positive_definite=True), - la.LinearOperatorDiag( + tf.linalg.LinearOperatorDiag( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], @@ -223,17 +368,20 @@ class VectorDiffeomixture(distribution_lib.Distribution): distribution, loc=None, scale=None, - quadrature_grid_and_probs=None, + quadrature_size=8, + quadrature_fn=quadrature_scheme_softmaxnormal_quantiles, validate_args=False, allow_nan_stats=True, name="VectorDiffeomixture"): - """Constructs the VectorDiffeomixture on `R**k`. + """Constructs the VectorDiffeomixture on `R**d`. Args: - mix_loc: `float`-like `Tensor`. Represents the `location` parameter of the - SoftmaxNormal used for selecting one of the `K` affine transformations. - mix_scale: `float`-like `Tensor`. Represents the `scale` parameter of the - SoftmaxNormal used for selecting one of the `K` affine transformations. + mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. Represents + the `location` parameter of the SoftmaxNormal used for selecting one of + the `K` affine transformations. + mix_scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. + Represents the `scale` parameter of the SoftmaxNormal used for selecting + one of the `K` affine transformations. distribution: `tf.Distribution`-like instance. Distribution from which `d` iid samples are used as input to the selected affine transformation. Must be a scalar-batch, scalar-event distribution. Typically @@ -252,10 +400,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): `k`-th element represents the `scale` used for the `k`-th affine transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices - quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s - representing the sample points and the corresponding (possibly - normalized) weight. When `None`, defaults to: - `np.polynomial.hermite.hermgauss(deg=8)`. + quadrature_size: Python `int` scalar representing number of + quadrature points. + quadrature_fn: Python callable taking `mix_loc`, `mix_scale`, + `quadrature_size`, `validate_args` and returning `tuple(grid, probs)` + representing the SoftmaxNormal grid and corresponding normalized weight. + normalized) weight. + Default value: `quadrature_scheme_softmaxnormal_quantiles`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -322,11 +473,8 @@ class VectorDiffeomixture(distribution_lib.Distribution): raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - grid, probs = distribution_util.process_quadrature_grid_and_probs( - quadrature_grid_and_probs, dtype, validate_args) - self._quadrature_grid = grid - self._quadrature_probs = probs - self._quadrature_size = distribution_util.dimension_size(probs, axis=0) + self._grid, probs = tuple(quadrature_fn( + mix_loc, mix_scale, quadrature_size, validate_args)) # Note: by creating the logits as `log(prob)` we ensure that # `self.mixture_distribution.logits` is equivalent to @@ -336,22 +484,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, allow_nan_stats=allow_nan_stats) - mix_loc = maybe_check_mix_param( - mix_loc, "mix_loc", dtype, validate_args) - mix_scale = maybe_check_mix_param( - mix_scale, "mix_scale", dtype, validate_args) - asserts = distribution_util.maybe_check_scalar_distribution( distribution, dtype, validate_args) if asserts: - mix_loc = control_flow_ops.with_dependencies(asserts, mix_loc) + self._grid = control_flow_ops.with_dependencies( + asserts, self._grid) self._distribution = distribution - # shape: [B, deg] - self._interpolate_weight = math_ops.sigmoid( - mix_loc - + np.sqrt(2.) * mix_scale * grid) - self._interpolated_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, @@ -359,15 +498,16 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( - interpolate_loc(self._quadrature_size, - self._interpolate_weight, - loc), - interpolate_scale(self._quadrature_size, - self._interpolate_weight, - scale)))] + interpolate_loc(self._grid, loc), + interpolate_scale(self._grid, scale)))] - self._batch_shape_, self._event_shape_ = determine_batch_event_shapes( - mix_loc, mix_scale, self._endpoint_affine) + [ + self._batch_shape_, + self._batch_shape_tensor_, + self._event_shape_, + self._event_shape_tensor_, + ] = determine_batch_event_shapes(self._grid, + self._endpoint_affine) super(VectorDiffeomixture, self).__init__( dtype=dtype, @@ -386,8 +526,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=( - [mix_loc, mix_scale] - + distribution._graph_parents # pylint: disable=protected-access + distribution._graph_parents # pylint: disable=protected-access + [loc_ for loc_ in loc if loc_ is not None] + [p for scale_ in scale for p in scale_.graph_parents]), name=name) @@ -403,9 +542,9 @@ class VectorDiffeomixture(distribution_lib.Distribution): return self._distribution @property - def interpolate_weight(self): + def grid(self): """Grid of mixing probabilities, one for each grid point.""" - return self._interpolate_weight + return self._grid @property def endpoint_affine(self): @@ -417,27 +556,17 @@ class VectorDiffeomixture(distribution_lib.Distribution): """Affine transformation for each convex combination of `K` components.""" return self._interpolated_affine - @property - def quadrature_grid(self): - """Quadrature grid points.""" - return self._quadrature_grid - - @property - def quadrature_probs(self): - """Quadrature normalized weights.""" - return self._quadrature_probs - def _batch_shape_tensor(self): - return self._batch_shape_ + return self._batch_shape_tensor_ def _batch_shape(self): - return tensor_shape.TensorShape(static_value(self._batch_shape_)) + return self._batch_shape_ def _event_shape_tensor(self): - return self._event_shape_ + return self._event_shape_tensor_ def _event_shape(self): - return tensor_shape.TensorShape(static_value(self._event_shape_)) + return self._event_shape_ def _sample_n(self, n, seed=None): x = self.distribution.sample( @@ -450,25 +579,44 @@ class VectorDiffeomixture(distribution_lib.Distribution): # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # ids as a [n]-shaped vector. - batch_size = reduce_prod(self.batch_shape_tensor()) - ids = self._mixture_distribution.sample( + batch_size = self.batch_shape.num_elements() + if batch_size is None: + batch_size = array_ops.reduce_prod(self.batch_shape_tensor()) + mix_batch_size = self.mixture_distribution.batch_shape.num_elements() + if mix_batch_size is None: + mix_batch_size = math_ops.reduce_prod( + self.mixture_distribution.batch_shape_tensor()) + ids = self.mixture_distribution.sample( sample_shape=concat_vectors( [n], distribution_util.pick_vector( self.is_scalar_batch(), np.int32([]), - [batch_size])), + [batch_size // mix_batch_size])), seed=distribution_util.gen_new_seed( seed, "vector_diffeomixture")) - - # Stride `quadrature_size` for `batch_size` number of times. + # We need to flatten batch dims in case mixture_distribution has its own + # batch dims. + ids = array_ops.reshape(ids, shape=concat_vectors( + [n], + distribution_util.pick_vector( + self.is_scalar_batch(), + np.int32([]), + np.int32([-1])))) + + # Stride `components * quadrature_size` for `batch_size` number of times. + stride = self.grid.shape.with_rank_at_least( + 2)[-2:].num_elements() + if stride is None: + stride = array_ops.reduce_prod( + array_ops.shape(self.grid)[-2:]) offset = math_ops.range(start=0, - limit=batch_size * self._quadrature_size, - delta=self._quadrature_size, + limit=batch_size * stride, + delta=stride, dtype=ids.dtype) weight = array_ops.gather( - array_ops.reshape(self.interpolate_weight, shape=[-1]), + array_ops.reshape(self.grid, shape=[-1]), ids + offset) weight = weight[..., array_ops.newaxis] @@ -500,10 +648,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): self.mixture_distribution.logits - fldj + log_prob, axis=-1) def _mean(self): - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) - + p = self._expand_mix_distribution_probs() m = self._expand_base_distribution_mean() mean = None for k, aff in enumerate(self.interpolated_affine): @@ -537,9 +682,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._covariance_of_mean_given_quadrature_component(diag_only=True)) def _mean_of_covariance_given_quadrature_component(self, diag_only): - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) + p = self.mixture_distribution.probs # To compute E[Cov(Z|V)], we'll add matrices within three categories: # scaled-identity, diagonal, and full. Then we'll combine these at the end. @@ -611,10 +754,9 @@ class VectorDiffeomixture(distribution_lib.Distribution): def _covariance_of_mean_given_quadrature_component(self, diag_only): square = math_ops.square if diag_only else vec_osquare - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) - + p = self._expand_mix_distribution_probs() + if not diag_only: + p = p[..., array_ops.newaxis, :] # Assuming event.ndims=1. m = self._expand_base_distribution_mean() cov_e_z_given_v = None @@ -638,17 +780,25 @@ class VectorDiffeomixture(distribution_lib.Distribution): m.set_shape(self.batch_shape.concatenate(self.event_shape)) return m - -def maybe_check_mix_param(param, name, expected_base_dtype, validate_args): - """Helper which checks validity of `mix_loc` and `mix_scale` init args.""" + def _expand_mix_distribution_probs(self): + p = self.mixture_distribution.probs # [B, deg] + deg = p.shape.with_rank_at_least(1)[-1].value + if deg is None: + deg = array_ops.shape(p)[-1] + event_ndims = self.event_shape.ndims + if event_ndims is None: + event_ndims = array_ops.shape(self.event_shape_tensor())[0] + expand_shape = array_ops.concat([ + self.mixture_distribution.batch_shape_tensor(), + array_ops.ones([event_ndims], dtype=dtypes.int32), + [deg], + ], axis=0) + return array_ops.reshape(p, shape=expand_shape) + + +def maybe_check_quadrature_param(param, name, validate_args): + """Helper which checks validity of `loc` and `scale` init args.""" with ops.name_scope(name="check_" + name, values=[param]): - param = ops.convert_to_tensor(param, dtype=expected_base_dtype, name=name) - - if param.dtype.base_dtype != expected_base_dtype: - raise TypeError( - "dtype mismatch; {}.base_dtype=\"{}\" is not \"{}\".".format( - name, param.dtype.base_dtype.name, expected_base_dtype.name)) - assertions = [] if param.shape.ndims is not None: if param.shape.ndims == 0: @@ -679,79 +829,84 @@ def maybe_check_mix_param(param, name, expected_base_dtype, validate_args): return param -def determine_batch_event_shapes(mix_loc, mix_scale, endpoint_affine): +def determine_batch_event_shapes(grid, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): - mix_batch_shape = distribution_util.prefer_static_broadcast_shape( - array_ops.shape(mix_loc, name="mix_loc_shape"), - array_ops.shape(mix_scale, name="mix_scale_shape")) - if isinstance(mix_batch_shape, tensor_shape.TensorShape): - mix_batch_shape = mix_batch_shape.with_rank_at_least(1)[:-1] - else: - s = static_value(mix_batch_shape) - if s is not None: - mix_batch_shape = ops.convert_to_tensor( - s[:-1], dtype=dtypes.int32, name="mix_batch_shape") - else: - mix_batch_shape = mix_batch_shape[:-1] - - # We broadcast with a 1D constant to automatically make the result a - # TensorShape if possible. - batch_shape = distribution_util.prefer_static_broadcast_shape( - mix_batch_shape, - constant_op.constant([], dtype=dtypes.int32, name="batch_shape")) - event_shape = constant_op.constant( - [], dtype=dtypes.int32, name="event_shape") + # grid # shape: [B, k, q] + # endpoint_affine # len=k, shape: [B, d, d] + batch_shape = grid.shape[:-2] + batch_shape_tensor = array_ops.shape(grid)[:-2] + event_shape = None + event_shape_tensor = None + + def _set_event_shape(shape, shape_tensor): + if event_shape is None: + return shape, shape_tensor + return (array_ops.broadcast_static_shape(event_shape, shape), + array_ops.broadcast_dynamic_shape( + event_shape_tensor, shape_tensor)) + for aff in endpoint_affine: - b, e = distribution_util.shapes_from_loc_and_scale(aff.shift, aff.scale) - if batch_shape is None: - batch_shape = distribution_util.prefer_static_broadcast_shape( - mix_batch_shape, b) - else: - batch_shape = distribution_util.prefer_static_broadcast_shape( - batch_shape, b) - event_shape = distribution_util.prefer_static_broadcast_shape( - event_shape, e) - if isinstance(batch_shape, tensor_shape.TensorShape): - batch_shape = ops.convert_to_tensor( - batch_shape.as_list(), dtype=dtypes.int32, name="batch_shape") - if isinstance(event_shape, tensor_shape.TensorShape): - event_shape = ops.convert_to_tensor( - event_shape.as_list(), dtype=dtypes.int32, name="event_shape") - return batch_shape, event_shape - - -def interpolate_loc(deg, interpolate_weight, loc): + if aff.shift is not None: + batch_shape = array_ops.broadcast_static_shape( + batch_shape, aff.shift.shape[:-1]) + batch_shape_tensor = array_ops.broadcast_dynamic_shape( + batch_shape_tensor, array_ops.shape(aff.shift)[:-1]) + event_shape, event_shape_tensor = _set_event_shape( + aff.shift.shape[-1:], array_ops.shape(aff.shift)[-1:]) + + if aff.scale is not None: + batch_shape = array_ops.broadcast_static_shape( + batch_shape, aff.scale.batch_shape) + batch_shape_tensor = array_ops.broadcast_dynamic_shape( + batch_shape_tensor, aff.scale.batch_shape_tensor()) + event_shape, event_shape_tensor = _set_event_shape( + tensor_shape.TensorShape([aff.scale.range_dimension]), + aff.scale.range_dimension_tensor()[array_ops.newaxis]) + + return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor + + +def interpolate_loc(grid, loc): """Helper which interpolates between two locs.""" if len(loc) != 2: raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(loc))) - with ops.name_scope("interpolate_loc", values=[interpolate_weight, loc]): + deg = grid.shape.with_rank_at_least(1)[-1].value + if deg is None: + raise ValueError("Num quadrature grid points must be known prior " + "to graph execution.") + with ops.name_scope("interpolate_loc", values=[grid, loc]): if loc is None or loc[0] is None and loc[1] is None: return [None]*deg - w = interpolate_weight[..., array_ops.newaxis, :] # shape: [B, 1, deg] + # shape: [B, 1, k, deg] + w = grid[..., array_ops.newaxis, :, :] loc = [x[..., array_ops.newaxis] # shape: [B, e, 1] if x is not None else None for x in loc] if loc[0] is None: - x = (1. - w) * loc[1] # shape: [B, e, deg] + x = w[..., 1, :] * loc[1] # shape: [B, e, deg] elif loc[1] is None: - x = w * loc[0] # shape: [B, e, deg] + x = w[..., 0, :] * loc[0] # shape: [B, e, deg] else: delta = loc[0] - loc[1] - x = w * delta + loc[1] # shape: [B, e, deg] + x = w[..., 0, :] * delta + loc[1] # shape: [B, e, deg] return [x[..., k] for k in range(deg)] # list(shape:[B, e]) -def interpolate_scale(deg, interpolate_weight, scale): +def interpolate_scale(grid, scale): """Helper which interpolates between two scales.""" if len(scale) != 2: raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - with ops.name_scope("interpolate_scale", values=[interpolate_weight]): + deg = grid.shape.with_rank_at_least(1)[-1].value + if deg is None: + raise ValueError("Num quadrature grid points must be known prior " + "to graph execution.") + with ops.name_scope("interpolate_scale", values=[grid]): return [linop_add_lib.add_operators([ - linop_scale(interpolate_weight[..., k], scale[0]), - linop_scale(1. - interpolate_weight[..., k], scale[1]), - ])[0] for k in range(deg)] + linop_scale(grid[..., k, q], s) + for k, s in enumerate(scale) + ])[0] for q in range(deg)] def linop_scale(w, op): @@ -791,39 +946,12 @@ def linop_scale(w, op): def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" - args_ = [static_value(x) for x in args] + args_ = [distribution_util.static_value(x) for x in args] if any(vec is None for vec in args_): return array_ops.concat(args, axis=0) return [val for vec in args_ for val in vec] -def reduce_prod(x): - """Same as `math_ops.reduce_prod` but statically if possible.""" - x_ = static_value(x) - if x_ is not None: - return np.prod(x_, dtype=x.dtype.as_numpy_dtype) - return array_ops.reduce_prod(x) - - -def ndims_from_shape(shape): - """Returns `Tensor`'s `rank` implied by a `Tensor` shape.""" - if shape.shape.ndims not in (None, 1): - raise ValueError("input is not a valid shape: not 1D") - if not shape.dtype.is_integer: - raise TypeError("input is not a valid shape: wrong dtype") - if shape.shape.is_fully_defined(): - return shape.shape.as_list()[0] - return array_ops.shape(shape)[0] - - -def ndims(x): - """Returns rank, statically if possible.""" - x = ops.convert_to_tensor(x) - if x.shape.ndims is not None: - return x.shape.ndims - return array_ops.rank(x) - - def add(x, y): """Adds inputs; interprets `None` as zero.""" if x is None: @@ -836,3 +964,18 @@ def add(x, y): def vec_osquare(x): """Computes the outer-product of a (batch of) vector, i.e., x.T x.""" return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] + + +def softmax(x, axis, name=None): + """Equivalent to tf.nn.softmax but works around b/70297725.""" + with ops.name_scope(name, "softmax", [x, axis]): + x = ops.convert_to_tensor(x, name="x") + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x, name="ndims")) + axis = ops.convert_to_tensor(axis, dtype=dtypes.int32, name="axis") + axis_ = tensor_util.constant_value(axis) + if axis_ is not None: + axis = np.int(ndims + axis_ if axis_ < 0 else axis_) + else: + axis = array_ops.where(axis < 0, ndims + axis, axis) + return nn_ops.softmax(x, axis=axis) diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 356d78b67a8107750f68f7f84d73d1231f5b2b03..526fe2d39aef9aed833b889de80e849c469435e7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -89,14 +89,13 @@ class VectorExponentialDiag( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. # The first component has pdf exp{-x}, the second 0.5 exp{-x / 2} - vex = ds.VectorExponentialDiag(scale_diag=[1., 2.]) + vex = tfd.VectorExponentialDiag(scale_diag=[1., 2.]) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([3., 4.]).eval() # shape: [] @@ -107,7 +106,7 @@ class VectorExponentialDiag( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialDiag(loc, scale_diag) + vex = tfd.VectorExponentialDiag(loc, scale_diag) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index b313a851b381e5b3a057fd17e6c2ef4eb0fc34f1..9d5fd9ac4178a1ae29b1ce32f304b22fd3d234dc 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -107,16 +107,15 @@ class VectorExponentialLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. mat = [[1.0, 0.1], [0.1, 1.0]] - vex = ds.VectorExponentialLinearOperator( - scale=la.LinearOperatorFullMatrix(mat)) + vex = tfd.VectorExponentialLinearOperator( + scale=tf.linalg.LinearOperatorFullMatrix(mat)) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([1., 2.]).eval() # shape: [] @@ -127,9 +126,9 @@ class VectorExponentialLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialLinearOperator( + vex = tfd.VectorExponentialLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 0e3867809a820f49cfa7f5282c47f786626481a6..8dd983b750d9b39775e570800006011f4968f7f3 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -101,10 +101,10 @@ class VectorLaplaceDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -118,7 +118,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -136,7 +136,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate VectorLaplace's. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index c7abdbb4caf9bee4cbd5991eb5d652f20dd0f8d1..ec485c95c15da2794b67d2699d2bdd9db97bb6c4 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -109,8 +109,7 @@ class VectorLaplaceLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate VectorLaplace with some desired covariance. mu = [1., 2, 3] @@ -124,9 +123,9 @@ class VectorLaplaceLinearOperator( # [ 0.1, -0.3, 0.4]]) # Divide scale by sqrt(2) so that the final covariance will be what we want. - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) + scale=tf.linalg.LinearOperatorLowerTriangular(scale / tf.sqrt(2.))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -143,9 +142,9 @@ class VectorLaplaceLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 544a8710709a0afb56c6ae6f36d35de892e8e420..e1ccf116457a97261b9ce3965552764771d3bdd2 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -143,7 +143,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): broadcastable with `event_shape`. distribution: `tf.Distribution`-like instance. Distribution from which `k` iid samples are used as input to transformation `F`. Default is - `ds.Normal(0., 1.)`. + `tf.distributions.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 29d41ab81c62d621c3c3533e1449341e9a085645..8c67647a618d22a58428d78865c4ebf7d98bdf9e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -91,14 +91,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): Extra leading dimensions, if provided, allow for batches. ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate vector Student's t-distribution. mu = [1., 2, 3] chol = [[1., 0, 0.], [1, 3, 0], [1, 2, 3]] - vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(df=2, loc=mu, scale_tril=chol) # Evaluate this on an observation in R^3, returning a scalar. vt.prob([-1., 0, 1]) @@ -107,7 +107,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): mu = [[1., 2, 3], [11, 22, 33]] chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal. - vt = ds.VectorStudentT(loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(loc=mu, scale_tril=chol) # Evaluate this on a two observations, each in R^3, returning a length two # tensor. diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index dcc370cd00d5f93cd5b145a31fd58ef5041a86a8..09242ee47ddd044dfc99e22d5b7751a989c86485 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -76,3 +76,6 @@ For an introduction to eager execution in TensorFlow, see: ## Changelog - 2017/10/31: Initial preview release. +- 2017/12/01: Example of dynamic neural network: + [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). + See [README.md](python/examples/spinn/README.md) for details. diff --git a/tensorflow/contrib/eager/proto/BUILD b/tensorflow/contrib/eager/proto/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..aedfec8924e7314addd22349c0576a84a58d9aa3 --- /dev/null +++ b/tensorflow/contrib/eager/proto/BUILD @@ -0,0 +1,24 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_proto_library( + name = "checkpointable_object_graph_proto", + srcs = [ + "checkpointable_object_graph.proto", + ], + visibility = ["//tensorflow/contrib/eager/python:__subpackages__"], +) diff --git a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto new file mode 100644 index 0000000000000000000000000000000000000000..c962638aa11c06dcd5be6a794314e029ae84e572 --- /dev/null +++ b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; + +package tensorflow.contrib.eager; + +// Prototype for an addition to BundleHeaderProto which saves extra information +// about the objects which own variables, allowing for more robust checkpoint +// loading into modified programs. + +message CheckpointableObjectGraph { + message Object { + message ObjectReference { + // An index into `CheckpointableObjectGraph.nodes`, indicating the object + // being referenced. + int32 node_id = 1; + // A numeric identifier for this object within its parent. + int32 local_uid = 2; + // A user-provided name for the edge. May be blank/omitted, in which case + // there is no explicitly provided local name; fall back on local_uid. + string local_name = 3; + } + + message VariableReference { + // A name for the variable which is unique within the object which owns + // it. Does not include a name_scope or variable_scope prefix. + string local_name = 1; + // The full name of the variable. Used to allow name-based loading of + // checkpoints which were saved using an object-based API. + string full_name = 2; + } + + message SlotVariableReference { + // An index into `CheckpointableObjectGraph.nodes`, indicating the object + // which created the variable that this variable is slotting for. + int32 original_variable_node_id = 1; + // The local name of the variable being slotted for within the object that + // owns it. + string original_variable_local_name = 2; + // The name of the slot (e.g. "m"/"v"). + string slot_name = 3; + // The full name of the slot variable. Used to allow name-based loading of + // checkpoints which were saved using an object-based API. + string full_name = 4; + } + + // Objects which this object depends on. + repeated ObjectReference children = 1; + // Non-slot variables owned by this object. + repeated VariableReference variables = 2; + // Slot variables owned by this object. + repeated SlotVariableReference slot_variables = 3; + } + + repeated Object nodes = 1; +} diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index bf2e883bc53c3281ef89d1200f5a089305ef3e72..086315464c99811371d836aed290b5068729adb0 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -19,6 +19,7 @@ py_library( "//tensorflow/python:framework_test_lib", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:script_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", @@ -103,37 +104,6 @@ cuda_py_test( ], ) -py_library( - name = "summary_writer", - srcs = ["summary_writer.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/summary:gen_summary_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary_op_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - ], -) - -cuda_py_test( - name = "summary_writer_test", - srcs = ["summary_writer_test.py"], - additional_deps = [ - ":summary_writer", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:constant_op", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_library( name = "metrics", srcs = [ @@ -232,6 +202,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":network", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:constant_op", "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", @@ -246,6 +217,39 @@ py_test( ], ) +py_library( + name = "checkpointable", + srcs = ["checkpointable.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "checkpointable_test", + srcs = ["checkpointable_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpointable", + ":network", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "@six_archive//:six", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py new file mode 100644 index 0000000000000000000000000000000000000000..b141ffb2bc03b8e38f8481bc044c3aae7e156c15 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable.py @@ -0,0 +1,392 @@ +"""An object-local variable management scheme.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import saver as saver_lib + +_CheckpointableReference = collections.namedtuple( + "_CheckpointableReference", + [ + "name", # The local name if explicitly specified, else None. + "local_uid", # 0 for the first dependency, 1 for the next, ... Used for + # routing checkpointed variables to their correct + # Checkpointables when "name" is not set (see docstring of + # `track_checkpointable`). + "ref" # The Checkpointable object being referenced. + ]) + +_OwnedVariable = collections.namedtuple( + "_OwnedVariable", + [ + "name", # The variable's (local) name. + "variable" # The owned variable object. + ]) + +# Validation regular expression for the local names of Checkpointable +# objects. In particular, disallows "/" in names, and reserves +# underscore-prefixed names. +_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$") + +# Keyword for identifying that the next bit of a checkpoint variable name is a +# slot name. May not be the local name of a checkpointable. Checkpoint names for +# slot variables look like: +# +# /<_OPTIMIZER_SLOTS_NAME>// +# +# Where is a full path from the checkpoint root to the +# variable being slotted for. +_OPTIMIZER_SLOTS_NAME = "_OPTIMIZER_SLOT" + + +class Checkpointable(object): + """Manages variables and dependencies on other objects. + + To make reliable checkpoints, all `Checkpointable`s on which this object + depends must be registered in the constructor using `track_checkpointable` in + a deterministic order, and if possible they should be named. Variables may be + created using `add_variable` outside of the constructor and in any order, but + only these variables will be saved. + """ + + def __init__(self): + # Basically less useful OrderedDicts but without the reference cycles. + # TODO(allenl): Switch these to OrderedDict once TensorFlow supports only + # Python 3.6+. + self._checkpoint_dependencies = [] # A list of _CheckpointableReference + # objects. + self._dependency_names = set() + self._owned_variables = [] # A list of _OwnedVariable objects. + self._owned_variable_names = set() + + def add_variable(self, name, shape, dtype=None, initializer=None, **kwargs): + """Create a new variable object to be saved with this `Checkpointable`. + + If the user has requested that this object or another `Checkpointable` which + depends on this object be restored from a checkpoint (deferred loading + before variable object creation), `initializer` may be ignored and the value + from the checkpoint used instead. + + Args: + name: A name for the variable. Must be unique within this object. + shape: The shape of the variable. + dtype: The data type of the variable. + initializer: The initializer to use. Ignored if deferred loading has been + requested. + **kwargs: Passed to get_variable. + + Returns: + The new variable object. + + Raises: + ValueError: If the variable name is not unique. + """ + if name in self._owned_variable_names: + raise ValueError( + ("A variable named '%s' already exists in this Checkpointable, but " + "Checkpointable.add_variable called to create another with " + "that name. Variable names must be unique within a Checkpointable " + "object.") % (name,)) + if "getter" in kwargs: + # Allow the getter to be overridden, typically because there is a need for + # compatibility with some other variable creation mechanism. This should + # be relatively uncommon in user code. + getter = kwargs.pop("getter") + else: + getter = variable_scope.get_variable + # TODO(allenl): handle deferred loading + new_variable = getter( + name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs) + self._owned_variables.append( + _OwnedVariable(name=name, variable=new_variable)) + self._owned_variable_names.add(name) + return new_variable + + def track_checkpointable(self, checkpointable, name=None): + """Declare a dependency on another `Checkpointable` object. + + Indicates that checkpoints for this object should include variables from + `checkpointable`. + + Variables in a checkpoint are mapped to `Checkpointable`s based on names if + provided when the checkpoint was written, but otherwise use the order those + `Checkpointable`s were declared as dependencies. Both `name` arguments and + the dependency declaration order should be deterministic. + + There are two sufficient conditions to avoid breaking existing checkpoints + when modifying a class: (1) New dependencies must be declared after existing + dependencies, and (2) dependencies which were previously declared may never + be removed (a trivial placeholder with the same name may be used instead). + + Args: + checkpointable: A `Checkpointable` which this object depends on. + name: A local name for `checkpointable`, used for loading checkpoints into + the correct objects. If provided, it must be unique within this + `Checkpointable`. If None, dependency declaration order is used instead. + + Returns: + `checkpointable`, for convenience when declaring a dependency and + assigning to a member variable in one statement. + + Raises: + RuntimeError: If __init__ was not called. + TypeError: If `checkpointable` does not inherit from `Checkpointable`. + ValueError: For invalid names. + """ + if not hasattr(self, "_checkpoint_dependencies"): + raise RuntimeError("Need to call Checkpointable.__init__ before calling " + "Checkpointable.track_checkpointable().") + if not isinstance(checkpointable, Checkpointable): + raise TypeError( + ("Checkpointable.track_checkpointable() passed type %s, not a " + "Checkpointable.") % (type(checkpointable),)) + if name is not None: + if not _VALID_LOCAL_NAME.match(name): + raise ValueError( + ("Checkpointable names must match the regular expression '%s', but " + "got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern, + name)) + if name in self._dependency_names: + raise ValueError( + ("Called Checkpointable.track_checkpointable() with name='%s', but " + "a Checkpointable with this name is already declared as a " + "dependency. If provided, names must be unique.") % (name,)) + self._dependency_names.add(name) + self._checkpoint_dependencies.append( + _CheckpointableReference( + name=name, + ref=checkpointable, + # TODO(allenl): Should this be exposed to allow users to stop + # depending on things and still load checkpoints when not using + # names? + local_uid=len(self._checkpoint_dependencies))) + return checkpointable + + @property + def checkpoint_dependencies(self): + """Other `Checkpointable` objects on which this object depends.""" + return self._checkpoint_dependencies + + +def _breadth_first_checkpointable_traversal(root_checkpointable): + """Find shortest paths to all variables owned by dependencies of root.""" + bfs_sorted = [] + root_checkpointable_reference = _CheckpointableReference( + name=None, local_uid=0, ref=root_checkpointable) + to_visit = collections.deque([root_checkpointable_reference]) + path_to_root = {root_checkpointable_reference: ()} + while to_visit: + current_checkpointable = to_visit.popleft() + bfs_sorted.append(current_checkpointable) + for child_checkpointable in ( + current_checkpointable.ref.checkpoint_dependencies): + if child_checkpointable not in path_to_root: + path_to_root[child_checkpointable] = ( + path_to_root[current_checkpointable] + (child_checkpointable,)) + to_visit.append(child_checkpointable) + return bfs_sorted, path_to_root + + +def _object_prefix_from_path(path_to_root): + return "/".join((checkpointable.name if checkpointable.name else "_%d" % ( + checkpointable.local_uid,)) for checkpointable in path_to_root) + + +def _escape_variable_name(variable_name): + # We need to support slashes in variable names for compatibility, since this + # naming scheme is being patched in to things like Layer.add_variable where + # slashes were previously accepted. We also want to use slashes to indicate + # edges traversed to reach the variable, so we escape forward slashes in + # variable names. + return variable_name.replace("_S_", "_S_.").replace(r"/", r"_S__") + + +def _variable_naming_for_object(path_to_root): + """Make a function for naming variables in an object.""" + # Name non-slot variables: + # + # / + # + # is not necessarily unique, but this is fine since we also + # save the graph of `Checkpointable`s with the checkpoint. Even if this path + # no longer exists because of a change in the Python program, we can look up + # the `Checkpointable` which owns the variable in the checkpoint's graph and + # use another path if one still exists. + + object_prefix = _object_prefix_from_path(path_to_root) + if object_prefix: + object_prefix += "/" + + def _name_single_variable(owned_variable): + """Names a variable within an object.""" + return object_prefix + _escape_variable_name(owned_variable.name) + + return _name_single_variable + + +def _slot_variable_naming_for_optimizer(optimizer, path_to_root): + """Make a function for naming slot variables in an optimizer.""" + # Name slot variables: + # + # /<_OPTIMIZER_SLOTS_NAME>// + # + # where is exactly the checkpoint name used for the original + # variable, including the path from the checkpoint root and the local name in + # the object which owns it. Note that we only save slot variables if the + # variable it's slotting for is also being saved. + + optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, + _object_prefix_from_path(path_to_root)) + + def _name_slot_variable(variable_path, slot_name): + """With an optimizer specified, name a slot variable.""" + + if not _VALID_LOCAL_NAME.match(slot_name): + # Slot variable names include the name of the slot. We need to + # validate that part of the name to be sure that the checkpoint name + # is a valid name scope name. + raise ValueError( + ("Could not save slot variables for optimizer %s, because its " + "slot name has invalid characters (got '%s', was expecting it " + "to match the regular expression '%s').") % + (optimizer, slot_name, _VALID_LOCAL_NAME.pattern)) + + return variable_path + optimizer_identifier + slot_name + + return _name_slot_variable + + +def _serialize_non_slot_variables(checkpointable_objects, path_to_root, + object_graph_proto): + """Name non-slot variables and add them to `object_graph_proto`.""" + named_variables = {} + non_slot_variables = [] + checkpoint_node_ids = {} + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + checkpoint_node_ids[checkpointable] = checkpoint_id + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + naming_scheme = _variable_naming_for_object(path_to_root[checkpointable]) + object_proto = object_graph_proto.nodes.add() + for owned_variable in checkpointable.ref._owned_variables: # pylint: disable=protected-access + variable_name = naming_scheme(owned_variable) + named_variables[variable_name] = owned_variable.variable + non_slot_variables.append(( + variable_name, # The variable's full checkpoint name + owned_variable, # The variable's _OwnedVariable object + checkpoint_id)) # The checkpoint ID of the node which owns this + # variable. + variable_proto = object_proto.variables.add() + variable_proto.local_name = owned_variable.name + # Figure out the name-based Saver's name for this variable. + saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [owned_variable.variable], convert_variable_to_tensor=False) + variable_full_name, = saver_dict.keys() + variable_proto.full_name = variable_full_name + + for child in checkpointable.ref.checkpoint_dependencies: + child_proto = object_proto.children.add() + child_proto.node_id = checkpoint_node_ids[child] + child_proto.local_uid = child.local_uid + if child.name is not None: + child_proto.local_name = child.name + return named_variables, non_slot_variables + + +def _serialize_slot_variables(checkpointable_objects, path_to_root, + non_slot_variables, object_graph_proto): + """Name slot variables and add them to `object_graph_proto`.""" + named_slot_variables = {} + for optimizer_checkpoint_id, checkpointable_ref in enumerate( + checkpointable_objects): + if isinstance(checkpointable_ref.ref, optimizer_lib.Optimizer): + optimizer_object_proto = object_graph_proto.nodes[optimizer_checkpoint_id] + naming_scheme = _slot_variable_naming_for_optimizer( + optimizer=checkpointable_ref.ref, + path_to_root=path_to_root[checkpointable_ref]) + slot_names = checkpointable_ref.ref.get_slot_names() + for (variable_path, owned_variable, + original_node_checkpoint_id) in non_slot_variables: + for slot_name in slot_names: + slot_variable = checkpointable_ref.ref.get_slot( + owned_variable.variable, slot_name) + if slot_variable is not None: + checkpoint_name = naming_scheme( + variable_path=variable_path, slot_name=slot_name) + named_slot_variables[checkpoint_name] = slot_variable + slot_variable_proto = optimizer_object_proto.slot_variables.add() + slot_variable_proto.slot_name = slot_name + # Figure out the name-based Saver's name for this variable. + saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [slot_variable], convert_variable_to_tensor=False) + slot_variable_full_name, = saver_dict.keys() + slot_variable_proto.full_name = slot_variable_full_name + slot_variable_proto.original_variable_local_name = ( + owned_variable.name) + slot_variable_proto.original_variable_node_id = ( + original_node_checkpoint_id) + return named_slot_variables + + +# TODO(allenl): Convenience utility for saving multiple objects (i.e. construct +# a root Checkpointable if passed a list of Checkpointables). +def _serialize_object_graph(root_checkpointable): + """Determine checkpoint keys for variables and build a serialized graph. + + Non-slot variables are keyed based on a shortest path from the root saveable + to the object which owns the variable (i.e. the one which called + `Checkpointable.add_variable` to create it). + + Slot variables are keyed based on a shortest path to the variable being + slotted for, a shortest path to their optimizer, and the slot name. + + Args: + root_checkpointable: A `Checkpointable` object whose variables (including + the variables of dependencies, recursively) should be saved. + + Returns: + A tuple of (named_variables, object_graph_proto): + named_variables: A dictionary mapping names to variable objects. + object_graph_proto: A CheckpointableObjectGraph protocol buffer containing + the serialized object graph and variable references. + + Raises: + ValueError: If there are invalid characters in an optimizer's slot names. + """ + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + + # Gather non-slot variables. + named_variables, non_slot_variables = _serialize_non_slot_variables( + checkpointable_objects, path_to_root, object_graph_proto) + + # Gather slot variables which are associated with variables gathered above. + named_slot_variables = _serialize_slot_variables( + checkpointable_objects, path_to_root, non_slot_variables, + object_graph_proto) + + named_variables.update(named_slot_variables) + return named_variables, object_graph_proto diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f820990bbe5fe6c9b4cdf890680aaad0847010c0 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_test.py @@ -0,0 +1,277 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import six + +from tensorflow.contrib.eager.python import checkpointable +from tensorflow.contrib.eager.python import network as network_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import adam +from tensorflow.python.training import training_util + + +class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + core.Dense.__init__(self, *args, **kwargs) + + def add_variable(self, name, shape, **kwargs): + # Calls both Checkpointable.add_variable and Layer.add_variable. Eventually + # Layer.add_variable should inherit from Checkpointable and simply call + # super and then do post-processing. + return checkpointable.Checkpointable.add_variable( + self, + name=name, + shape=shape, + getter=functools.partial(core.Dense.add_variable, self), + **kwargs) + + +# pylint: disable=not-callable +class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): + + def __init__(self): + network_lib.Network.__init__(self) + checkpointable.Checkpointable.__init__(self) + + def track_layer(self, layer, name=None): + self.track_checkpointable(layer, name=name) + return super(CheckpointableNetwork, self).track_layer(layer) + + +class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + adam.AdamOptimizer.__init__(self, *args, **kwargs) + + # NOTE: Copied from AdamOptimizer with modifications to use add_variable + # for non-slot variables. These contortions are necessary to maintain + # checkpoint compatibility with variable.name based saving. + def _create_slots(self, var_list): + # Create the beta1 and beta2 accumulators on the same device as the first + # variable. Sort the var_list to make sure this device is consistent across + # workers (these need to go on the same PS, otherwise some updates are + # silently ignored). + first_var = min(var_list, key=lambda x: x.name) + + create_new = self._beta1_power is None + if not create_new and context.in_graph_mode(): + create_new = (self._beta1_power.graph is not first_var.graph) + + if create_new: + with ops.colocate_with(first_var): + + def _variable_getter(name, shape, dtype, initializer): + del shape, dtype # not used, but there for compatibility + return variable_scope.variable( + name=name, initial_value=initializer, trainable=False) + + self._beta1_power = self.add_variable( + name="beta1_power", + shape=[], + initializer=self._beta1, + getter=_variable_getter) + self._beta2_power = self.add_variable( + name="beta2_power", + shape=[], + initializer=self._beta2, + getter=_variable_getter) + # Create slots for the first and second moments. + for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + # TODO(allenl): Override slot variable creation (_get_or_make_slot, + # _get_or_make_slot_with_initializer, _zeros_slot) to allow deferred + # loading. Likely no need to run this through add_variable, since gathering + # slot variables is special cased anyway. + + +class MyNetwork(CheckpointableNetwork): + """A concrete Network for testing.""" + + def __init__(self): + super(MyNetwork, self).__init__() + self._named = self.track_layer( + CheckpointableDenseLayer(1, use_bias=True), name="named_dense") + self._unnamed = self.track_layer( + CheckpointableDenseLayer(1, use_bias=False)) + + def call(self, values): + return self._unnamed(self._named(values)) + + +class Root(checkpointable.Checkpointable): + """A stand-in for a Trainer class.""" + + def __init__(self, optimizer, network): + super(Root, self).__init__() + self.track_checkpointable(optimizer, name="optimizer") + self.track_checkpointable(network, name="network") + self._global_step = None + + @property + def global_step(self): + if self._global_step is None: + # Get the default create_global_step utility to actually call + # self.add_variable, by setting a custom getter. + def _owned_variable_as_custom_getter(getter, *args, **kwargs): + return self.add_variable(*args, getter=getter, **kwargs) + + with variable_scope.variable_scope( + "", custom_getter=_owned_variable_as_custom_getter): + self._global_step = training_util.create_global_step() + return self._global_step + + +class CheckpointNamingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNamingWithOptimizer(self): + input_value = constant_op.constant([[3.]]) + network = MyNetwork() + # A nuisance Network using the same optimizer. Its slot variables should not + # go in the checkpoint, since it is never depended on. + other_network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root_checkpointable = Root(optimizer=optimizer, network=network) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value), + global_step=root_checkpointable.global_step) + optimizer.minimize( + lambda: other_network(input_value), + global_step=root_checkpointable.global_step) + else: + train_op = optimizer.minimize( + network(input_value), global_step=root_checkpointable.global_step) + optimizer.minimize( + other_network(input_value), + global_step=root_checkpointable.global_step) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + named_variables, serialized_graph = checkpointable._serialize_object_graph( + root_checkpointable) + expected_checkpoint_names = ( + # Created in the root node, so no prefix. + "global_step", + # No name provided to track_checkpointable(), so the position (1, after + # the named track_checkpointable() which is 0) is used instead. + "network/_1/kernel", + # track_checkpointable() with a name provided, so that's used + "network/named_dense/kernel", + "network/named_dense/bias", + # The optimizer creates two non-slot variables + "optimizer/beta1_power", + "optimizer/beta2_power", + # Slot variables + "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/m", + "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/v", + "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/m", + "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/v", + "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m", + "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/v", + ) + six.assertCountEqual(self, expected_checkpoint_names, + named_variables.keys()) + # Check that we've mapped to the right variable objects (not exhaustive) + self.assertEqual("global_step:0", named_variables["global_step"].name) + self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0", + named_variables["network/_1/kernel"].name) + self.assertEqual("my_network/checkpointable_dense_layer/kernel:0", + named_variables["network/named_dense/kernel"].name) + self.assertEqual("beta1_power:0", + named_variables["optimizer/beta1_power"].name) + self.assertEqual("beta2_power:0", + named_variables["optimizer/beta2_power"].name) + # Spot check the generated protocol buffers. + self.assertEqual(0, serialized_graph.nodes[0].children[0].local_uid) + self.assertEqual("optimizer", + serialized_graph.nodes[0].children[0].local_name) + optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ + 0].node_id] + self.assertEqual("beta1_power", optimizer_node.variables[0].local_name) + self.assertEqual("beta1_power", optimizer_node.variables[0].full_name) + self.assertEqual( + "kernel", optimizer_node.slot_variables[0].original_variable_local_name) + original_variable_owner = serialized_graph.nodes[ + optimizer_node.slot_variables[0].original_variable_node_id] + self.assertEqual("kernel", original_variable_owner.variables[0].local_name) + self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) + # We strip off the :0 suffix, as variable.name-based saving does. + self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam", + optimizer_node.slot_variables[0].full_name) + self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam:0", + optimizer.get_slot( + var=named_variables["network/named_dense/kernel"], + name="m").name) + + def _get_checkpoint_name(self, name): + root = checkpointable.Checkpointable() + with variable_scope.variable_scope("get_checkpoint_name"): + # Create the variable in a variable scope so that we get more relaxed + # naming rules (variables outside a scope may not start with "_", "/" or + # "-"). Since we don't use the scope part of the name, these cases are + # somewhat annoying. + root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64) + named_variables, _ = checkpointable._serialize_object_graph(root) + checkpoint_name, = named_variables.keys() + with ops.name_scope("root/" + checkpoint_name): + pass # Make sure we can use this as an op name if we prefix it. + return checkpoint_name + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableNameEscaping(self): + self.assertEqual(r"a_S__b_S__c", self._get_checkpoint_name(r"a/b/c")) + self.assertEqual(r"", self._get_checkpoint_name(r"")) + self.assertEqual(r"_S__", self._get_checkpoint_name(r"/")) + self.assertEqual(r"_S___S_._", self._get_checkpoint_name(r"/_S__")) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNumberedPath(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + root.track_checkpointable(leaf) + leaf.add_variable(name="v", shape=[]) + named_variables, _ = checkpointable._serialize_object_graph(root) + variable_name, = named_variables.keys() + self.assertEqual(r"_0/v", variable_name) + + @test_util.run_in_graph_and_eager_modes() + def testLocalNameValidation(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + with self.assertRaisesRegexp(ValueError, "invalid name"): + # Leading underscores are reserved, which avoids conflicts with + # un-named edges in paths and the optimizer slots identifier. + root.track_checkpointable(leaf, name="_12") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index bd0ab02ecf7ae6025e08dde1c3ddc634db9255c1..3faaeef5903615ea122800a6690117dde682e830 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -110,7 +110,7 @@ class Evaluator(object): return self._all_metric_results() else: def f(): - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( summary_logdir).as_default(), summary_ops.always_record_summaries(): return self._all_metric_results() if context.in_eager_mode(): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index aa21a6ab994acf929890ecebc07a86cf7ebf97db..6aef010a2139c4cd2ae19c008aa21d4e3592ca98 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -11,5 +11,6 @@ py_library( "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index d0130ebd118dbaff4f0161c8b2528764c6103e02..7bc5007c5655bed81b5600ee283c35bd332a1ebe 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -85,7 +85,7 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= - summary_writer = tf.contrib.summary.create_summary_file_writer(logdir) + summary_writer = tf.contrib.summary.create_file_writer(logdir) # Training loop. for i, (xs, ys) in enumerate(tfe.Iterator(dataset)): diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index bfb7d5a9002787f6544d383de58150661ac2bde3..bb121c7704b4772dde520ddc928a13c50ec8bb18 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -190,9 +190,9 @@ def main(_): else: train_dir = None test_dir = None - summary_writer = tf.contrib.summary.create_summary_file_writer( + summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_summary_writer = tf.contrib.summary.create_file_writer( test_dir, flush_millis=10000, name='test') checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 14c82c87a72457d414c4a1d3c53d4d1a68a400e6..23317886e712323f4b520000e0fd372734fc53a1 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -73,7 +73,7 @@ class ResNet50GraphTest(tf.test.TestCase): tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() with tf.contrib.summary.always_record_summaries(): - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(): model = resnet50.ResNet50(data_format()) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 582f4837c6f3197081cb558063e963866d173f29..d8d8644dde10498e5fd480f92b69656fca1558dd 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -95,7 +95,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device): diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 609cbd28772c3ae8da70648ca5b1b264a8a255e2..40919f2d4cf511eb35fac954719286366aef6c7c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -247,9 +247,9 @@ def main(_): log_dir = os.path.join(FLAGS.dir, "summaries") tf.gfile.MakeDirs(log_dir) - train_summary_writer = tf.contrib.summary.create_summary_file_writer( + train_summary_writer = tf.contrib.summary.create_file_writer( os.path.join(log_dir, "train"), flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_summary_writer = tf.contrib.summary.create_file_writer( os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") with tf.device(device): diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a1f8a759e2a556bc219f0aa13942f293c4f34cfa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -0,0 +1,42 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "data", + srcs = ["data.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//third_party/py/numpy"], +) + +py_test( + name = "data_test", + size = "small", + srcs = ["data_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":data", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "spinn_test", + size = "medium", + srcs = ["spinn_test.py"], + additional_deps = [ + ":data", + "//third_party/examples/eager/spinn", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/python/eager:test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], + tags = ["no_pip"], # because spinn.py is under third_party/. +) diff --git a/tensorflow/contrib/eager/python/examples/spinn/README.md b/tensorflow/contrib/eager/python/examples/spinn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eb0637df473e22e5d39ca1b0816464cb2b7c6435 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/README.md @@ -0,0 +1,13 @@ +# SPINN: Dynamic neural network with TensorFlow eager execution + +This directory contains files supporting the +[spinn.py model in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/spinn.py), +including + +- `data.py`: Utility library for loading and preprocessing the SNLI and GloVe + data. +- `data_test.py` and `spinn_test.py`: Unit tests for the data and model modules. + +See the [README.md in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/README.md) +for detailed background, license and usage information regarding the SPINN code. + diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e046320f78541bef4e091e97f08fd51857af83 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -0,0 +1,350 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities of SNLI data and GloVe word vectors for SPINN model. + +See more details about the SNLI data set at: + https://nlp.stanford.edu/projects/snli/ + +See more details about the GloVe pretrained word embeddings at: + https://nlp.stanford.edu/projects/glove/ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import math +import os +import random + +import numpy as np + +POSSIBLE_LABELS = ("entailment", "contradiction", "neutral") + +UNK_CODE = 0 # Code for unknown word tokens. +PAD_CODE = 1 # Code for padding tokens. + +SHIFT_CODE = 3 +REDUCE_CODE = 2 + +WORD_VECTOR_LEN = 300 # Embedding dimensions. + +LEFT_PAREN = "(" +RIGHT_PAREN = ")" +PARENTHESES = (LEFT_PAREN, RIGHT_PAREN) + + +def get_non_parenthesis_words(items): + """Get the non-parenthesis items from a SNLI parsed sentence. + + Args: + items: Data items from a parsed SNLI setence, with parentheses. E.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of non-parenthis word items, all converted to lower case. E.g., + ["man", "wearing", "pass", ... + """ + return [x.lower() for x in items if x not in PARENTHESES and x] + + +def get_shift_reduce(items): + """Obtain shift-reduce vector from a list of items from the SNLI data. + + Args: + items: Data items as a list of str, e.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and + `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE` + and `REDUCE_CODE`. + """ + trans = [] + for item in items: + if item == LEFT_PAREN: + continue + elif item == RIGHT_PAREN: + trans.append(REDUCE_CODE) + else: + trans.append(SHIFT_CODE) + return trans + + +def pad_and_reverse_word_ids(sentences): + """Pad a list of sentences to the common maximum length + 1. + + Args: + sentences: A list of sentences as a list of list of integers. Each integer + is a word ID. Each list of integer corresponds to one sentence. + + Returns: + A numpy.ndarray of shape (num_sentences, max_length + 1), wherein max_length + is the maximum sentence length (in # of words). Each sentence is reversed + and then padded with an extra one at head, as required by the model. + """ + max_len = max(len(sent) for sent in sentences) + for sent in sentences: + if len(sent) < max_len: + sent.extend([PAD_CODE] * (max_len - len(sent))) + # Reverse in time order and pad an extra one. + sentences = np.fliplr(np.array(sentences, dtype=np.int64)) + sentences = np.concatenate( + [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1) + return sentences + + +def pad_transitions(sentences_transitions): + """Pad a list of shift-reduce transitions to the maximum length.""" + max_len = max(len(transitions) for transitions in sentences_transitions) + for transitions in sentences_transitions: + if len(transitions) < max_len: + transitions.extend([PAD_CODE] * (max_len - len(transitions))) + return np.array(sentences_transitions, dtype=np.int64) + + +def load_vocabulary(data_root): + """Load vocabulary from SNLI data files. + + Args: + data_root: Root directory of the data. It is assumed that the SNLI data + files have been downloaded and extracted to the "snli/snli_1.0" + subdirectory of it. + + Returns: + Vocabulary as a set of strings. + + Raises: + ValueError: If SNLI data files cannot be found. + """ + snli_path = os.path.join(data_root, "snli") + snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt") + file_names = glob.glob(snli_glob_pattern) + if not file_names: + raise ValueError( + "Cannot find SNLI data files at %s. " + "Please download and extract SNLI data first." % snli_glob_pattern) + + print("Loading vocabulary...") + vocab = set() + for file_name in file_names: + with open(os.path.join(snli_path, file_name), "rt") as f: + for i, line in enumerate(f): + if i == 0: + continue + items = line.split("\t") + premise_words = get_non_parenthesis_words(items[1].split(" ")) + hypothesis_words = get_non_parenthesis_words(items[2].split(" ")) + vocab.update(premise_words) + vocab.update(hypothesis_words) + return vocab + + +def load_word_vectors(data_root, vocab): + """Load GloVe word vectors for words present in the vocabulary. + + Args: + data_root: Data root directory. It is assumed that the GloVe file + has been downloaded and extracted at the "glove/" subdirectory of it. + vocab: A `set` of words, representing the vocabulary. + + Returns: + 1. word2index: A dict from lower-case word to row index in the embedding + matrix, i.e, `embed` below. + 2. embed: The embedding matrix as a float32 numpy array. Its shape is + [vocabulary_size, WORD_VECTOR_LEN]. vocabulary_size is len(vocab). + WORD_VECTOR_LEN is the embedding dimension (300). + + Raises: + ValueError: If GloVe embedding file cannot be found. + """ + glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt") + if not os.path.isfile(glove_path): + raise ValueError( + "Cannot find GloVe embedding file at %s. " + "Please download and extract GloVe embeddings first." % glove_path) + + print("Loading word vectors...") + + word2index = dict() + embed = [] + + embed.append([0] * WORD_VECTOR_LEN) # + embed.append([0] * WORD_VECTOR_LEN) # + word2index[""] = UNK_CODE + word2index[""] = PAD_CODE + + with open(glove_path, "rt") as f: + for line in f: + items = line.split(" ") + word = items[0] + if word in vocab and word not in word2index: + word2index[word] = len(embed) + vector = np.array([float(item) for item in items[1:]]) + assert (WORD_VECTOR_LEN,) == vector.shape + embed.append(vector) + embed = np.array(embed, dtype=np.float32) + return word2index, embed + + +def calculate_bins(length2count, min_bin_size): + """Cacluate bin boundaries given a histogram of lengths and mininum bin size. + + Args: + length2count: A `dict` mapping length to sentence count. + min_bin_size: Minimum bin size in terms of total number of sentence pairs + in the bin. + + Returns: + A `list` representing the right bin boundaries, starting from the inclusive + right boundary of the first bin. For example, if the output is + [10, 20, 35], + it means there are three bins: [1, 10], [11, 20] and [21, 35]. + """ + bounds = [] + lengths = sorted(length2count.keys()) + cum_count = 0 + for length in lengths: + cum_count += length2count[length] + if cum_count >= min_bin_size: + bounds.append(length) + cum_count = 0 + if bounds[-1] != lengths[-1]: + bounds.append(lengths[-1]) + return bounds + + +class SnliData(object): + """A split of SNLI data.""" + + def __init__(self, data_file, word2index, sentence_len_limit=-1): + """SnliData constructor. + + Args: + data_file: Full path to the data file, e.g., + "/tmp/spinn-data/snli/snli_1.0/snli_1.0.train.txt" + word2index: A dict from lower-case word to row index in the embedding + matrix (see `load_word_vectors()` for details). + sentence_len_limit: Maximum allowed sentence length (# of words). + A value of <= 0 means unlimited. Sentences longer than this limit + are currently discarded, not truncated. + """ + + self._labels = [] + self._premises = [] + self._premise_transitions = [] + self._hypotheses = [] + self._hypothesis_transitions = [] + + with open(data_file, "rt") as f: + for i, line in enumerate(f): + if i == 0: + # Skip header line. + continue + items = line.split("\t") + if items[0] not in POSSIBLE_LABELS: + continue + + premise_items = items[1].split(" ") + hypothesis_items = items[2].split(" ") + premise_words = get_non_parenthesis_words(premise_items) + hypothesis_words = get_non_parenthesis_words(hypothesis_items) + + if (sentence_len_limit > 0 and + (len(premise_words) > sentence_len_limit or + len(hypothesis_words) > sentence_len_limit)): + # TODO(cais): Maybe truncate; do not discard. + continue + + premise_ids = [ + word2index.get(word, UNK_CODE) for word in premise_words] + hypothesis_ids = [ + word2index.get(word, UNK_CODE) for word in hypothesis_words] + + self._premises.append(premise_ids) + self._hypotheses.append(hypothesis_ids) + self._premise_transitions.append(get_shift_reduce(premise_items)) + self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items)) + assert (len(self._premise_transitions[-1]) == + 2 * len(premise_words) - 1) + assert (len(self._hypothesis_transitions[-1]) == + 2 * len(hypothesis_words) - 1) + + self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1) + + assert len(self._labels) == len(self._premises) + assert len(self._labels) == len(self._hypotheses) + assert len(self._labels) == len(self._premise_transitions) + assert len(self._labels) == len(self._hypothesis_transitions) + + def num_batches(self, batch_size): + """Calculate number of batches given batch size.""" + return int(math.ceil(len(self._labels) / batch_size)) + + def get_generator(self, batch_size): + """Obtain a generator for batched data. + + All examples of this SnliData object are randomly shuffled, sorted + according to the maximum sentence length of the premise and hypothesis + sentences in the pair, and batched. + + Args: + batch_size: Desired batch size. + + Returns: + A generator for data batches. The generator yields a 5-tuple: + label: An array of the shape (batch_size,). + premise: An array of the shape (max_premise_len, batch_size), wherein + max_premise_len is the maximum length of the (padded) premise + sentence in the batch. + premise_transitions: An array of the shape (2 * max_premise_len -3, + batch_size). + hypothesis: Same as `premise`, but for hypothesis sentences. + hypothesis_transitions: Same as `premise_transitions`, but for + hypothesis sentences. + All the elements of the 5-tuple have dtype `int64`. + """ + # Randomly shuffle examples. + zipped = list(zip( + self._labels, self._premises, self._premise_transitions, + self._hypotheses, self._hypothesis_transitions)) + random.shuffle(zipped) + # Then sort the examples by maximum of the premise and hypothesis sentence + # lengths in the pair. During training, the batches are expected to be + # shuffled. So it is okay to leave them sorted by max length here. + (labels, premises, premise_transitions, hypotheses, + hypothesis_transitions) = zip( + *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3])))) + + def _generator(): + begin = 0 + while begin < len(labels): + # The sorting above and the batching here makes sure that sentences of + # similar max lengths are batched together, minimizing the inefficiency + # due to uneven max lengths. The sentences are batched differently in + # each call to get_generator() due to the shuffling before sotring + # above. The pad_and_reverse_word_ids() and pad_transitions() functions + # take care of any remaning unevenness of the max sentence lengths. + end = min(begin + batch_size, len(labels)) + # Transpose, because the SPINN model requires time-major, instead of + # batch-major. + yield (labels[begin:end], + pad_and_reverse_word_ids(premises[begin:end]).T, + pad_transitions(premise_transitions[begin:end]).T, + pad_and_reverse_word_ids(hypotheses[begin:end]).T, + pad_transitions(hypothesis_transitions[begin:end]).T) + begin = end + return _generator diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f0b37c5099e45b7e3b258b258c0a203c36b3b7 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -0,0 +1,243 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for SPINN data module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.spinn import data + + +class DataTest(tf.test.TestCase): + + def setUp(self): + super(DataTest, self).setUp() + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(DataTest, self).tearDown() + + def testGenNonParenthesisWords(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") + self.assertEqual( + ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing", + "in", "a", "crowd", "of", "people", "."], + data.get_non_parenthesis_words(seq_with_parse.split(" "))) + + def testGetShiftReduce(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") + self.assertEqual( + [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2, + 3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" "))) + + def testPadAndReverseWordIds(self): + id_sequences = [[0, 2, 3, 4, 5], + [6, 7, 8], + [9, 10, 11, 12, 13, 14, 15, 16]] + self.assertAllClose( + [[1, 1, 1, 1, 5, 4, 3, 2, 0], + [1, 1, 1, 1, 1, 1, 8, 7, 6], + [1, 16, 15, 14, 13, 12, 11, 10, 9]], + data.pad_and_reverse_word_ids(id_sequences)) + + def testPadTransitions(self): + unpadded = [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2]] + self.assertAllClose( + [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2, 1, 1]], + data.pad_transitions(unpadded)) + + def testCalculateBins(self): + length2count = { + 1: 10, + 2: 15, + 3: 25, + 4: 40, + 5: 35, + 6: 10} + self.assertEqual([2, 3, 4, 5, 6], + data.calculate_bins(length2count, 20)) + self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40)) + self.assertEqual([4, 6], data.calculate_bins(length2count, 60)) + + def testLoadVoacbulary(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt") + os.makedirs(snli_1_0_dir) + + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + with open(fake_dev_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Quux quuz?\t.Corge grault!\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + self.assertSetEqual( + {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"}, + vocab) + + def testLoadVoacbularyWithoutFileRaisesError(self): + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + def testLoadWordVectors(self): + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", ",", "foo", "bar", "baz"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = {"foo", "bar", "baz", "qux", "."} + # Notice that "qux" is not present in `words`. + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + self.assertEqual(6, len(word2index)) + self.assertEqual(0, word2index[""]) + self.assertEqual(1, word2index[""]) + self.assertEqual(2, word2index["."]) + self.assertEqual(3, word2index["foo"]) + self.assertEqual(4, word2index["bar"]) + self.assertEqual(5, word2index["baz"]) + self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :]) + self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :]) + self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :]) + self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :]) + + def testLoadWordVectorsWithoutFileRaisesError(self): + vocab = {"foo", "bar", "baz", "qux", "."} + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + os.makedirs(os.path.join(self._temp_data_dir, "glove")) + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + def testSnliData(self): + """Unit test for SnliData objects.""" + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + self.assertEqual(4, train_data.num_batches(1)) + self.assertEqual(2, train_data.num_batches(2)) + self.assertEqual(2, train_data.num_batches(3)) + self.assertEqual(1, train_data.num_batches(4)) + + generator = train_data.get_generator(2)() + for i in range(2): + label, prem, prem_trans, hypo, hypo_trans = next(generator) + self.assertEqual(2, len(label)) + self.assertEqual((4, 2), prem.shape) + self.assertEqual((5, 2), prem_trans.shape) + self.assertEqual((3, 2), hypo.shape) + self.assertEqual((3, 2), hypo_trans.shape) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..84e25cf81a2223800c47994b26d000caddee6b01 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -0,0 +1,409 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import gc +import glob +import os +import shutil +import tempfile +import time + +import numpy as np +import tensorflow as tf + +# pylint: disable=g-bad-import-order +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.spinn import data +from third_party.examples.eager.spinn import spinn +from tensorflow.contrib.summary import summary_test_util +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +# pylint: enable=g-bad-import-order + + +def _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size): + """Generate a fake batch of SNLI data for testing.""" + with tf.device("cpu:0"): + labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64) + prem = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + prem_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + hypo = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + hypo_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + if tfe.num_gpus(): + labels = labels.gpu() + prem = prem.gpu() + prem_trans = prem_trans.gpu() + hypo = hypo.gpu() + hypo_trans = hypo_trans.gpu() + return labels, prem, prem_trans, hypo, hypo_trans + + +def _test_spinn_config(d_embed, d_out, logdir=None): + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict", + "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", + "d_out", "projection", "lr", "batch_size", "epochs", + "force_cpu", "logdir", "log_every", "dev_every", "save_every", + "lr_decay_every", "lr_decay_by"]) + return config_tuple( + d_hidden=d_embed, + d_proj=d_embed * 2, + d_tracker=8, + predict=False, + embed_dropout=0.1, + mlp_dropout=0.1, + n_mlp_layers=2, + d_mlp=32, + d_out=d_out, + projection=True, + lr=2e-2, + batch_size=2, + epochs=10, + force_cpu=False, + logdir=logdir, + log_every=1, + dev_every=2, + save_every=2, + lr_decay_every=1, + lr_decay_by=0.75) + + +class SpinnTest(test_util.TensorFlowTestCase): + + def setUp(self): + super(SpinnTest, self).setUp() + self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(SpinnTest, self).tearDown() + + def testBundle(self): + with tf.device(self._test_device): + lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32), + np.array([[0, -1], [-2, -3]], dtype=np.float32), + np.array([[0, 2], [4, 6]], dtype=np.float32), + np.array([[0, -2], [-4, -6]], dtype=np.float32)] + out = spinn._bundle(lstm_iter) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T, + out[0].numpy()) + self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T, + out[1].numpy()) + + def testUnbunbdle(self): + with tf.device(self._test_device): + state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32), + np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)] + out = spinn._unbundle(state) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]), + out[0].numpy()) + self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]), + out[1].numpy()) + + def testReducer(self): + with tf.device(self._test_device): + batch_size = 3 + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + left_in = [] + right_in = [] + tracking = [] + for _ in range(batch_size): + left_in.append(tf.random_normal((1, size * 2))) + right_in.append(tf.random_normal((1, size * 2))) + tracking.append(tf.random_normal((1, tracker_size * 2))) + + out = reducer(left_in, right_in, tracking=tracking) + self.assertEqual(batch_size, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual((1, size * 2), out[0].shape) + + def testReduceTreeLSTM(self): + with tf.device(self._test_device): + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]], + dtype=np.float32) + c1 = np.array([[0, 1], [2, 3]], dtype=np.float32) + c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32) + + h, c = reducer._tree_lstm(c1, c2, lstm_in) + self.assertEqual(tf.float32, h.dtype) + self.assertEqual(tf.float32, c.dtype) + self.assertEqual((2, 2), h.shape) + self.assertEqual((2, 2), c.shape) + + def testTracker(self): + with tf.device(self._test_device): + batch_size = 2 + size = 10 + tracker_size = 8 + buffer_length = 18 + stack_size = 3 + + tracker = spinn.Tracker(tracker_size, False) + tracker.reset_state() + + # Create dummy inputs for testing. + bufs = [] + buf = [] + for _ in range(buffer_length): + buf.append(tf.random_normal((batch_size, size * 2))) + bufs.append(buf) + self.assertEqual(1, len(bufs)) + self.assertEqual(buffer_length, len(bufs[0])) + self.assertEqual((batch_size, size * 2), bufs[0][0].shape) + + stacks = [] + stack = [] + for _ in range(stack_size): + stack.append(tf.random_normal((batch_size, size * 2))) + stacks.append(stack) + self.assertEqual(1, len(stacks)) + self.assertEqual(3, len(stacks[0])) + self.assertEqual((batch_size, size * 2), stacks[0][0].shape) + + for _ in range(2): + out1, out2 = tracker(bufs, stacks) + self.assertIsNone(out2) + self.assertEqual(batch_size, len(out1)) + self.assertEqual(tf.float32, out1[0].dtype) + self.assertEqual((1, tracker_size * 2), out1[0].shape) + + self.assertEqual(tf.float32, tracker.state.c.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.c.shape) + self.assertEqual(tf.float32, tracker.state.h.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.h.shape) + + def testSPINN(self): + with tf.device(self._test_device): + embedding_dims = 10 + d_tracker = 8 + sequence_length = 15 + num_transitions = 27 + + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict"]) + config = config_tuple( + embedding_dims, embedding_dims * 2, d_tracker, False) + s = spinn.SPINN(config) + + # Create some fake data. + buffers = tf.random_normal((sequence_length, 1, config.d_proj)) + transitions = tf.constant( + [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3], + [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2], + [3], [2], [2]], dtype=tf.int64) + self.assertEqual(tf.int64, transitions.dtype) + self.assertEqual((num_transitions, 1), transitions.shape) + + out = s(buffers, transitions, training=True) + self.assertEqual(tf.float32, out.dtype) + self.assertEqual((1, embedding_dims), out.shape) + + def testSNLIClassifierAndTrainer(self): + with tf.device(self._test_device): + vocab_size = 40 + batch_size = 2 + d_embed = 10 + sequence_length = 15 + d_out = 4 + + config = _test_spinn_config(d_embed, d_out) + + # Create fake embedding matrix. + embed = tf.random_normal((vocab_size, d_embed)) + + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + # Invoke model under non-training mode. + logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Invoke model under training model. + logits = model(prem, prem_trans, hypo, hypo_trans, training=True) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Calculate loss. + loss1 = trainer.loss(labels, logits) + self.assertEqual(tf.float32, loss1.dtype) + self.assertEqual((), loss1.shape) + + loss2, logits = trainer.train_batch( + labels, prem, prem_trans, hypo, hypo_trans) + self.assertEqual(tf.float32, loss2.dtype) + self.assertEqual((), loss2.shape) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + # Training on the batch should have led to a change in the loss value. + self.assertNotEqual(loss1.numpy(), loss2.numpy()) + + def testTrainSpinn(self): + """Test with fake toy SNLI data and GloVe vectors.""" + + # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + dev_data = data.SnliData(fake_train_file, word2index) + test_data = data.SnliData(fake_train_file, word2index) + print(embed) + + # 2. Create a fake config. + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir")) + + # 3. Test training of a SPINN model. + spinn.train_spinn(embed, train_data, dev_data, test_data, config) + + # 4. Load train loss values from the summary files and verify that they + # decrease with training. + summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0] + events = summary_test_util.events_from_file(summary_file) + train_losses = [event.summary.value[0].simple_value for event in events + if event.summary.value + and event.summary.value[0].tag == "train/loss"] + self.assertEqual(config.epochs, len(train_losses)) + self.assertLess(train_losses[-1], train_losses[0]) + + +class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): + + def benchmarkEagerSpinnSNLIClassifier(self): + test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + with tf.device(test_device): + burn_in_iterations = 2 + benchmark_iterations = 10 + + vocab_size = 1000 + batch_size = 128 + sequence_length = 15 + d_embed = 200 + d_out = 4 + + embed = tf.random_normal((vocab_size, d_embed)) + + config = _test_spinn_config(d_embed, d_out) + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + for _ in range(burn_in_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + + gc.collect() + start_time = time.time() + for _ in xrange(benchmark_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + wall_time = time.time() - start_time + # Named "examples"_per_sec to conform with other benchmarks. + extras = {"examples_per_sec": benchmark_iterations / wall_time} + self.report_benchmark( + name="Eager_SPINN_SNLIClassifier_Benchmark", + iters=benchmark_iterations, + wall_time=wall_time, + extras=extras) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 147b7047f42b7ccba5829b61370e82e217ce5838..0095ffa0db99d46d25654d73504d0d7d41c18b6f 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -757,7 +757,7 @@ For example, to record summaries once every 100 global steps, use: ```python tf.train.get_or_create_global_step() # Ensuring the global step variable exists -writer = tf.contrib.summary.create_summary_file_writer(logdir) +writer = tf.contrib.summary.create_file_writer(logdir) for _ in range(iterations): with writer.as_default(): diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index aa359b7a0d7d89e8788c323d1621798d1a22b658..2f8016ede3caee6dbb6fd8f5226f1464b5c3976b 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -73,7 +73,7 @@ class Metric(object): * `result()`: Computes and returns a final value for the metric from the variables in `self`. - Decendants may override `aggregate()`, but usually won't need to. It + Descendants may override `aggregate()`, but usually won't need to. It adds in the state from a list of metrics of the same type as `self`. (Default is to sum all the variables.) Note that users should not call `aggregate()`, it is for use by TensorFlow infrastructure. diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 96eb1b4f2a0e4c4af1f3310a2801b1b6aee285d6..1055f4563cd4608189281450aed512fbf5f31de1 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -67,7 +67,7 @@ class MetricsTest(test.TestCase): m([1, 10, 100]) training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( logdir, max_queue=0, name="t0").as_default(), summary_ops.always_record_summaries(): m.result() # As a side-effect will write summaries. diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 97eded7dca2c0594321a006fecb360e26675a005..e3c13cbd2e8ccd2ab79da74e0e97905c6ed5c02d 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -54,16 +54,81 @@ def _network_name_scope_naming(current_variable_scope): class Network(base.Layer): """Represents the composition of a set of Layers. - TODO(josh11b,ashankar): - - Should "trainable" be changeable on the Network object? - - Do we allow add_variable in Network? - - Detect layers used in __call__ that weren't registered with track_layer. - - Convert inputs to __call__ to tensors. - - Prevent variables from being created after the first __call__? - (Think about restoring from a checkpoint). + `Network` implements the `Layer` interface and adds convenience methods for + managing sub-`Layer`s, such as listing variables. + + `Layer`s (including other `Network`s) should be added via `track_layer`. They + can then be used when overriding the `Network.call` method: + + ```python + class TwoLayerNetwork(tfe.Network): + + def __init__(self, name): + super(TwoLayerNetwork, self).__init__(name=name) + self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,))) + self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,))) + + def call(self, inputs): + return self.layer_two(self.layer_one(inputs)) + ``` + + After constructing an object and calling the `Network`, a list of variables + created by tracked `Layer`s is available via `Network.variables`: + + ```python + net = TwoLayerNetwork(name="net") + output = net(tf.ones([1, 8])) + print([v.name for v in net.variables]) + ``` + + This example prints variable names, one kernel and one bias per + `tf.layers.Dense` layer: + + ``` + ['net/dense/kernel:0', + 'net/dense/bias:0', + 'net/dense_1/kernel:0', + 'net/dense_1/bias:0'] + ``` + + These variables can be passed to a `Saver` (`tf.train.Saver`, or + `tf.contrib.eager.Saver` when executing eagerly) to save or restore the + `Network`, typically alongside a global step and `tf.train.Optimizer` + variables when checkpointing during training. + + Note that the semantics of calling a `Network` with graph execution (i.e. not + executing eagerly) may change slightly in the future. Currently stateful ops + are pruned from the graph unless they or something that depends on them is + executed in a session, but this behavior is not consistent with eager + execution (where stateful ops are executed eagerly). `Layer`s from `tf.layers` + do not depend on this pruning and so will not be affected, but `Network`s + which rely on stateful ops being added to the graph but not executed (e.g. via + custom `Layer`s which manage stateful ops) may break with this change. """ + # TODO(josh11b,ashankar,allenl): + # - Should 'trainable' be changeable on the Network object? + # - Do we allow add_variable in Network? + # - Detect layers used in __call__ that weren't registered with track_layer. + # - Convert inputs to __call__ to tensors. def __init__(self, name=None): + """Configure the `Network`. + + Args: + name: The name to use for this `Network`. If specified, it must be unique + in the context where this `Network` is first + (1) added to another `Network` (in which case it must not share a name + with other `Layers` added to that `Network`), or + (2) built/called (in which case no other 'top-level' `Network`s may + share this name). + If unspecified or None, the `Network` will be named using its class + name, with a number appended if necessary for uniqueness (e.g. MyNetwork + -> 'my_network_1'). + + Raises: + ValueError: If `name` is not valid. Note that some naming errors will + instead be raised when the `Network` is called. + """ if isinstance(name, variable_scope.VariableScope): raise ValueError("VariableScopes are not valid Network names.") if name is not None and "/" in name: @@ -386,8 +451,30 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") - # TODO(josh11b): Support other Layer methods needed for graph mode, such as for - # losses and updates + def add_loss(self, losses, inputs=None): + raise RuntimeError( + "add_loss is not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") + + @property + def losses(self): + """Gather losses from `Layer`s in the `Network`. + + Note that when executing eagerly, `Layer.losses` evaluates + regularizers. When using graph execution, variable regularization ops have + already been created and are simply returned here. + + Returns: + A list of tensors. + """ + layer_losses = [] + for layer in self.layers: + layer_losses.extend(layer.losses) + return layer_losses + + # TODO(allenl): Support other Layer methods needed for graph mode, such as for + # updates class Sequential(Network): diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index e7835a63e6db926aa2d4b6c76c681c8a301757bd..3eb4f5f8b3954a7ed04d2ef1d4f119ad137e1e65 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import gc from tensorflow.contrib.eager.python import network +from tensorflow.contrib.layers.python.layers import regularizers from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test @@ -45,6 +46,22 @@ class MyNetwork(network.Network): return self.l1(x) +class RegularizedNetwork(network.Network): + + def __init__(self): + super(RegularizedNetwork, self).__init__() + self.l1 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0), + kernel_regularizer=regularizers.l1_regularizer(2.0))) + self.l2 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0))) + + def call(self, values): + return self.l2(self.l1(values)) + + class NetworkTest(test.TestCase): def _save_modify_load_network_built(self, net, global_step=None): @@ -484,6 +501,18 @@ class NetworkTest(test.TestCase): _check_op_prefixes(expected_prefix="my_network_1/dense/", checked_ops=checked_ops) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableRegularizers(self): + net = RegularizedNetwork() + net(constant_op.constant([[1.]])) + self.evaluate(net.variables[0].assign([[2.]])) + self.evaluate(net.variables[1].assign([3.])) + self.evaluate(net.variables[2].assign([[-2.]])) + self.evaluate(net.variables[3].assign([4.])) + self.assertAllEqual([4., 6., 8.], self.evaluate(net.losses)) + self.evaluate(net.variables[3].assign([5.])) + self.assertAllEqual([4., 6., 10.], self.evaluate(net.losses)) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) diff --git a/tensorflow/contrib/eager/python/summary_writer.py b/tensorflow/contrib/eager/python/summary_writer.py deleted file mode 100644 index 5d8c41b545b3c9fd03af85f302ba05a394f085a4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TensorBoard Summary Writer for TensorFlow Eager Execution.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import uuid - -from tensorflow.contrib.summary import gen_summary_ops -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_op_util -from tensorflow.python.ops import variable_scope - - -def _maybe_cpu(v): - if isinstance(v, (ops.EagerTensor, ops.Tensor)): - return v.cpu() - else: - return v - - -def _summary_writer_function(name, tensor, function, family=None): - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - function(tag, scope) - return True - return record - - -class SummaryWriter(object): - """Writes summaries for TensorBoard, compatible with eager execution. - - This class is the supported way of writing TensorBoard summaries under - eager execution. - """ - - _CPU_DEVICE = "cpu:0" - - def __init__(self, - logdir, - max_queue=10, - flush_secs=120, - filename_suffix=""): - """Summary writer for TensorBoard, compatible with eager execution. - - If necessary, multiple instances of `SummaryWriter` can be created, with - distinct `logdir`s and `name`s. Each `SummaryWriter` instance will retain - its independent `global_step` counter and data writing destination. - - Example: - ```python - writer = tfe.SummaryWriter("my_model") - - # ... Code that sets up the model and data batches ... - - for _ in xrange(train_iters): - loss = model.train_batch(batch) - writer.scalar("loss", loss) - writer.step() - ``` - - Args: - logdir: Directory in which summary files will be written. - max_queue: Number of summary items to buffer before flushing to - filesystem. If 0, summaries will be flushed immediately. - flush_secs: Number of secondsbetween forced commits to disk. - filename_suffix: Suffix of the event protobuf files in which the summary - data are stored. - - Raises: - ValueError: If this constructor is called not under eager execution. - """ - # TODO(apassos, ashankar): Make this class and the underlying - # contrib.summary_ops compatible with graph model and remove this check. - if not context.in_eager_mode(): - raise ValueError( - "Use of SummaryWriter is currently supported only with eager " - "execution enabled. File an issue at " - "https://github.com/tensorflow/tensorflow/issues/new to express " - "interest in fixing this.") - - # TODO(cais): Consider adding name keyword argument, which if None or empty, - # will register the global global_step that training_util.get_global_step() - # can find. - with context.device(self._CPU_DEVICE): - self._name = uuid.uuid4().hex - self._global_step = 0 - self._global_step_tensor = variable_scope.get_variable( - "global_step/summary_writer/" + self._name, - shape=[], dtype=dtypes.int64, - initializer=init_ops.zeros_initializer()) - self._global_step_dirty = False - self._resource = gen_summary_ops.summary_writer(shared_name=self._name) - gen_summary_ops.create_summary_file_writer( - self._resource, logdir, max_queue, flush_secs, filename_suffix) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device=self._CPU_DEVICE) - - def step(self): - """Increment the global step counter of this SummaryWriter instance.""" - self._global_step += 1 - self._global_step_dirty = True - - @property - def global_step(self): - """Obtain the current global_step value of this SummaryWriter instance. - - Returns: - An `int` representing the current value of the global_step of this - `SummaryWriter` instance. - """ - return self._global_step - - def _update_global_step_tensor(self): - with context.device(self._CPU_DEVICE): - if self._global_step_dirty: - self._global_step_dirty = False - return state_ops.assign(self._global_step_tensor, self._global_step) - else: - return self._global_step_tensor - - def generic(self, name, tensor, metadata, family=None): - """Write a generic-type summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A `Tensor` or compatible value type containing the value of the - summary. - metadata: Metadata about the summary. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_summary( - self._resource, - self._update_global_step_tensor(), - _maybe_cpu(tensor), - tag, - _maybe_cpu(metadata), - name=scope) - - def scalar(self, name, tensor, family=None): - """Write a scalar summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type containing a - single value. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - - Returns: - A summary writer function for scalars. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_scalar_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def histogram(self, name, tensor, family=None): - """Write a histogram summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type. Any shape. - Values to use to build the histogram. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_histogram_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def image(self, name, tensor, bad_color=None, max_images=3, family=None): - """Write an image summary.""" - with context.device(self._CPU_DEVICE): - if bad_color is None: - bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_image_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), bad_color_, max_images, - name=scope) - - def audio(self, name, tensor, sample_rate, max_outputs, family=None): - """Write an audio summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` - or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`, or - compatible value type. - sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the - signal in hertz. - max_outputs: Max number of batch elements to generate audio for. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_audio_summary( - self._resource, self._update_global_step_tensor(), - tag, - _maybe_cpu(tensor), - sample_rate=_maybe_cpu(sample_rate), - max_outputs=max_outputs, - name=scope) diff --git a/tensorflow/contrib/eager/python/summary_writer_test.py b/tensorflow/contrib/eager/python/summary_writer_test.py deleted file mode 100644 index 5ebb36d04fcba8f4558fa1c09716314af42f559f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit tests for eager execution SummaryWriter.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil -import tempfile - -import numpy as np - -from tensorflow.contrib.eager.python import summary_writer -from tensorflow.core.util import event_pb2 -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.lib.io import tf_record -from tensorflow.python.platform import gfile - - -class SummaryWriterTest(test.TestCase): - - def setUp(self): - super(SummaryWriterTest, self).setUp() - self._test_device = "gpu:0" if context.num_gpus() else "cpu:0" - self._tmp_logdir = tempfile.mkdtemp() - with context.device(self._test_device): - # Use max_queue=0 so that summaries are immediately flushed to filesystem, - # making testing easier. - self._writer = summary_writer.SummaryWriter(self._tmp_logdir, max_queue=0) - - def tearDown(self): - if os.path.isdir(self._tmp_logdir): - shutil.rmtree(self._tmp_logdir) - super(SummaryWriterTest, self).tearDown() - - def _readLastEvent(self, logdir=None): - if not logdir: - logdir = self._tmp_logdir - files = [f for f in gfile.ListDirectory(logdir) - if not gfile.IsDirectory(os.path.join(logdir, f))] - file_path = os.path.join(logdir, files[0]) - records = list(tf_record.tf_record_iterator(file_path)) - event = event_pb2.Event() - event.ParseFromString(records[-1]) - return event - - def testGlobalStep(self): - with context.device(self._test_device): - orig_step = self._writer.global_step - self._writer.step() - self.assertEqual(orig_step + 1, self._writer.global_step) - self.assertEqual(orig_step + 1, self._writer.global_step) - self._writer.step() - self._writer.step() - self.assertEqual(orig_step + 3, self._writer.global_step) - - def testGenericSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - with context.device("cpu:0"): - metadata = constant_op.constant("foo") - self._writer.generic("x", x, metadata) - event = self._readLastEvent() - self.assertEqual("x", event.summary.value[0].tag) - - def testScalarSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - self._writer.scalar("x", x) - event = self._readLastEvent() - self.assertTrue("x", event.summary.value[0].tag) - self.assertEqual(1337.0, event.summary.value[0].simple_value) - - def testHistogramSummary(self): - with context.device(self._test_device): - y = constant_op.constant([1.0, 3.0, 3.0, 7.0]) - self._writer.histogram("y", y) - event = self._readLastEvent() - self.assertEqual("y", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].histo) - - def testImageSummary(self): - with context.device(self._test_device): - a = constant_op.constant([[10.0, 20.0], [-20.0, -10.0]]) - self._writer.histogram("image1", a) - event = self._readLastEvent() - self.assertEqual("image1", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].image) - - def testAudioSummary(self): - with context.device(self._test_device): - w = constant_op.constant(np.random.rand(3, 10, 2), dtype=dtypes.float32) - fs = constant_op.constant(44100.0, dtype=dtypes.float32) - max_outputs = 1 - self._writer.audio("audio1", w, fs, max_outputs) - event = self._readLastEvent() - self.assertTrue(event.summary.value[0].audio) - - def testTwoSummaryWritersGlobalStepsWorkWithoutCrosstalk(self): - tmp_logdir2 = os.path.join(self._tmp_logdir, "_writer2_") - writer2 = summary_writer.SummaryWriter(tmp_logdir2, max_queue=0) - - self.assertEqual(0, writer2.global_step) - self._writer.step() - self.assertEqual(0, writer2.global_step) - writer2.step() - writer2.step() - writer2.step() - self.assertEqual(3, writer2.global_step) - - x = constant_op.constant(1337.0) - writer_orig_step = self._writer.global_step - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 1, event.step) - - writer2.scalar("x", x) - event = self._readLastEvent(tmp_logdir2) - self.assertEqual(3, event.step) - - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 2, event.step) - - -# TODO(cais): Add performance benchmark for SummaryWriter. - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 1697c879def8af5c05f3c9b11d318d570785d6de..770a7e3e7a01f3351c229b7fb53383240dd1f1c8 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -23,6 +23,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@list_devices @@num_gpus +@@py_func @@defun @@implicit_gradients @@implicit_value_and_gradients @@ -101,8 +102,10 @@ from tensorflow.python.framework.test_util import IsolateTest from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore +from tensorflow.python.ops import script_ops from tensorflow.python.util.all_util import remove_undocumented +py_func = script_ops.eager_py_func defun = function.defun implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 008ca7a5d17437213ad64a54dddd40ad37e81df0..bd65ece85d2bfc6b38ba3507d3e702241eaf6067 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -27,8 +27,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dnn", + ":dnn_linear_combined", ":extenders", ":head", + ":linear", ":logit_fns", ":multi_head", ":replicate_model_fn", @@ -73,6 +75,46 @@ py_test( ], ) +py_library( + name = "dnn_linear_combined", + srcs = ["python/estimator/dnn_linear_combined.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:nn", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:dnn_linear_combined", + ], +) + +py_test( + name = "dnn_linear_combined_test", + size = "medium", + srcs = ["python/estimator/dnn_linear_combined_test.py"], + shard_count = 3, + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":dnn_linear_combined", + ":head", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:nn", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python/estimator:dnn_testing_utils", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:linear_testing_utils", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "extenders", srcs = [ @@ -169,6 +211,42 @@ py_test( ], ) +py_library( + name = "linear", + srcs = ["python/estimator/linear.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:linear", + ], +) + +py_test( + name = "linear_test", + size = "small", + srcs = ["python/estimator/linear_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":head", + ":linear", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:linear_testing_utils", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "logit_fns", srcs = [ @@ -253,23 +331,24 @@ py_library( "//tensorflow/python:device", "//tensorflow/python:device_lib", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:util", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) cuda_py_test( name = "replicate_model_fn_test", - size = "small", + size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ "//tensorflow/python/estimator", @@ -297,5 +376,5 @@ cuda_py_test( "//tensorflow/python:variables", ":replicate_model_fn", ], - tags = ["requires-gpu-sm35"], + tags = ["multi_gpu"], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index cf727264cd5116915f6bd7f285e470cbc2e2742a..28c1f8b1809d27db697365b7bb50441f7820d2b4 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -20,10 +20,13 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import from tensorflow.contrib.estimator.python.estimator.dnn import * +from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * +from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -38,9 +41,12 @@ _allowed_symbols = [ 'multi_label_head', 'regression_head', 'DNNEstimator', + 'DNNLinearCombinedEstimator', + 'LinearEstimator', 'call_logit_fn', 'dnn_logit_fn_builder', 'linear_logit_fn_builder', + 'replicate_model_fn', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..ccaf1128bf23af734f7a5722a4dd8c1f0304fab7 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -0,0 +1,164 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow estimator for Linear and DNN joined training models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import dnn_linear_combined as dnn_linear_combined_lib +from tensorflow.python.ops import nn + + +class DNNLinearCombinedEstimator(estimator.Estimator): + """An estimator for TensorFlow Linear and DNN joined models with custom head. + + Note: This estimator is also known as wide-n-deep. + + Example: + + ```python + numeric_feature = numeric_column(...) + categorical_column_a = categorical_column_with_hash_bucket(...) + categorical_column_b = categorical_column_with_hash_bucket(...) + + categorical_feature_a_x_categorical_feature_b = crossed_column(...) + categorical_feature_a_emb = embedding_column( + categorical_column=categorical_feature_a, ...) + categorical_feature_b_emb = embedding_column( + categorical_column=categorical_feature_b, ...) + + estimator = DNNLinearCombinedEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + # wide settings + linear_feature_columns=[categorical_feature_a_x_categorical_feature_b], + linear_optimizer=tf.train.FtrlOptimizer(...), + # deep settings + dnn_feature_columns=[ + categorical_feature_a_emb, categorical_feature_b_emb, + numeric_feature], + dnn_hidden_units=[1000, 500, 100], + dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) + + # To apply L1 and L2 regularization, you can set optimizers as follows: + tf.train.ProximalAdagradOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.001) + # It is same for FtrlOptimizer. + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * for each `column` in `dnn_feature_columns` + `linear_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss is calculated by using mean squared error. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + model_dir=None, + linear_feature_columns=None, + linear_optimizer='Ftrl', + dnn_feature_columns=None, + dnn_optimizer='Adagrad', + dnn_hidden_units=None, + dnn_activation_fn=nn.relu, + dnn_dropout=None, + input_layer_partitioner=None, + config=None): + """Initializes a DNNLinearCombinedEstimator instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + linear_feature_columns: An iterable containing all the feature columns + used by linear part of the model. All items in the set must be + instances of classes derived from `FeatureColumn`. + linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the linear part of the model. Defaults to FTRL optimizer. + dnn_feature_columns: An iterable containing all the feature columns used + by deep part of the model. All items in the set must be instances of + classes derived from `FeatureColumn`. + dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the deep part of the model. Defaults to Adagrad optimizer. + dnn_hidden_units: List of hidden units per layer. All layers are fully + connected. + dnn_activation_fn: Activation function applied to each layer. If None, + will use `tf.nn.relu`. + dnn_dropout: When not None, the probability we will drop out + a given coordinate. + input_layer_partitioner: Partitioner for input layer. Defaults to + `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: RunConfig object to configure the runtime settings. + + Raises: + ValueError: If both linear_feature_columns and dnn_features_columns are + empty at the same time. + """ + linear_feature_columns = linear_feature_columns or [] + dnn_feature_columns = dnn_feature_columns or [] + self._feature_columns = ( + list(linear_feature_columns) + list(dnn_feature_columns)) + if not self._feature_columns: + raise ValueError('Either linear_feature_columns or dnn_feature_columns ' + 'must be defined.') + + def _model_fn(features, labels, mode, config): + return dnn_linear_combined_lib._dnn_linear_combined_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + linear_feature_columns=linear_feature_columns, + linear_optimizer=linear_optimizer, + dnn_feature_columns=dnn_feature_columns, + dnn_optimizer=dnn_optimizer, + dnn_hidden_units=dnn_hidden_units, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + + super(DNNLinearCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e4d34dc70ccaa4806ae8b8ed5001bd971ee7b4 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -0,0 +1,220 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dnn_linear_combined.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import dnn_linear_combined +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.python.estimator.canned import dnn_testing_utils +from tensorflow.python.estimator.canned import linear_testing_utils +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache + + +def _dnn_only_estimator_fn( + hidden_units, + feature_columns, + model_dir=None, + label_dimension=1, + weight_column=None, + optimizer='Adagrad', + activation_fn=nn.relu, + dropout=None, + input_layer_partitioner=None, + config=None): + return dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + model_dir=model_dir, + dnn_feature_columns=feature_columns, + dnn_optimizer=optimizer, + dnn_hidden_units=hidden_units, + dnn_activation_fn=activation_fn, + dnn_dropout=dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + + +class DNNOnlyEstimatorEvaluateTest( + dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__( + self, _dnn_only_estimator_fn) + + +class DNNOnlyEstimatorPredictTest( + dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorPredictTest.__init__( + self, _dnn_only_estimator_fn) + + +class DNNOnlyEstimatorTrainTest( + dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorTrainTest.__init__( + self, _dnn_only_estimator_fn) + + +def _linear_only_estimator_fn( + feature_columns, + model_dir=None, + label_dimension=1, + weight_column=None, + optimizer='Ftrl', + config=None, + partitioner=None): + return dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + model_dir=model_dir, + linear_feature_columns=feature_columns, + linear_optimizer=optimizer, + input_layer_partitioner=partitioner, + config=config) + + +class LinearOnlyEstimatorEvaluateTest( + linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__( + self, _linear_only_estimator_fn) + + +class LinearOnlyEstimatorPredictTest( + linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorPredictTest.__init__( + self, _linear_only_estimator_fn) + + +class LinearOnlyEstimatorTrainTest( + linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorTrainingTest.__init__( + self, _linear_only_estimator_fn) + + +class DNNLinearCombinedEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, + label_dimension, batch_size): + linear_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + dnn_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + feature_columns = linear_feature_columns + dnn_feature_columns + est = dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head(label_dimension=label_dimension), + linear_feature_columns=linear_feature_columns, + dnn_feature_columns=dnn_feature_columns, + dnn_hidden_units=(2, 2), + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + # learn y = x + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=label_dimension, + label_dimension=label_dimension, + batch_size=batch_size) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf4abe83d54504d55de73b63f369cceaf149dd2 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import linear as linear_lib + + +class LinearEstimator(estimator.Estimator): + """An estimator for TensorFlow linear models with user-specified head. + + Example: + + ```python + categorical_column_a = categorical_column_with_hash_bucket(...) + categorical_column_b = categorical_column_with_hash_bucket(...) + + categorical_feature_a_x_categorical_feature_b = crossed_column(...) + + # Estimator using the default optimizer. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b]) + + # Or estimator using the FTRL optimizer with regularization. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b]) + optimizer=tf.train.FtrlOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001 + )) + + def input_fn_train: # returns x, y (where y represents label's class index). + ... + estimator.train(input_fn=input_fn_train, steps=100) + def input_fn_eval: # returns x, y (where y represents label's class index). + ... + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + ... + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss and predicted output are determined by the specified head. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + feature_columns, + model_dir=None, + optimizer='Ftrl', + config=None, + partitioner=None): + """Initializes a `LinearEstimator` instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + optimizer: An instance of `tf.Optimizer` used to train the model. Defaults + to FTRL optimizer. + config: `RunConfig` object to configure the runtime settings. + partitioner: Optional. Partitioner for input layer. + """ + def _model_fn(features, labels, mode, config): + return linear_lib._linear_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + feature_columns=tuple(feature_columns or []), + optimizer=optimizer, + partitioner=partitioner, + config=config) + super(LinearEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/linear_test.py b/tensorflow/contrib/estimator/python/estimator/linear_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c63514eb688af48577f0a3b7ce9e7478309f2c30 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/linear_test.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for linear.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.contrib.estimator.python.estimator import linear +from tensorflow.python.estimator.canned import linear_testing_utils +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache + + +def _linear_estimator_fn( + weight_column=None, label_dimension=1, *args, **kwargs): + """Returns a LinearEstimator that uses regression_head.""" + return linear.LinearEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + *args, **kwargs) + + +class LinearEstimatorEvaluateTest( + linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__( + self, _linear_estimator_fn) + + +class LinearEstimatorPredictTest( + linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorPredictTest.__init__( + self, _linear_estimator_fn) + + +class LinearEstimatorTrainTest( + linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorTrainingTest.__init__( + self, _linear_estimator_fn) + + +class LinearEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, + label_dimension, batch_size): + feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + est = linear.LinearEstimator( + head=head_lib.regression_head(label_dimension=label_dimension), + feature_columns=feature_columns, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + # learn y = x + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=label_dimension, + label_dimension=label_dimension, + batch_size=batch_size) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index d9c83aa86577aa129458c56887ff4668c103d0db..598bd549c5cef7edde6bf94605aa8839b611e185 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -41,20 +41,25 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging +from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import training_util -def replicate_model_fn(model_fn, optimizer_fn, devices=None): +def replicate_model_fn(model_fn, + optimizer_fn, + loss_reduction=losses.Reduction.SUM, + devices=None): """Replicate `Estimator.model_fn` over GPUs within a single host. The given `model_fn` specifies a single forward pass of a model. To replicate such a model over GPUs, each GPU gets its own instance of the forward pass (a.k.a. a tower). The input features and labels get sharded into the chunks - that correspond to the number of GPUs. Each tower computes its own loss based + that correspond to the number of GPUs. Each tower computes a loss based on its input. For each such loss, gradients are computed. After that, the - available losses are summed to form aggregated loss. The available - gradients are summed too. Then, they update weights using the specified + available losses are aggregated to form aggregated loss. Available + gradients are summed. Then, they update weights using the specified optimizer. If `devices` are `None`, then all available GPUs are going to be used for @@ -101,7 +106,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): On reduction algorithms: Certain algorithms were chosen for aggregating results of computations on multiple towers: - - Losses from all towers are reduced using sum. + - Losses from all towers are reduced according to `loss_reduction`. - Gradients are reduced using sum for each trainable variable. - `eval_metrics_ops` are reduced per metric using `reduce_mean`. - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are @@ -109,7 +114,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): - For all other fields of `EstimatorSpec` the values of the first tower are taken. - On replication of variables: + On distribution of variables: Variables are not duplicated between towers. Instead, they are placed on a single device as defined above and shared across towers. @@ -123,6 +128,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): optimizer_fn: a function that returns an optimizer instance. The function may accept one `params` argument. This is the `params` argument as defined by `Estimator`. See the `Estimator` documentation for details. + loss_reduction: controls whether losses are summed or averaged. devices: Optional list of devices to replicate the model across. This argument can be used to replice only on the subset of available GPUs. If `None`, then all available GPUs are going to be used for replication. @@ -133,39 +139,91 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): conforms to the requirements of `Estimator`'s `model_fn` and can be used instead of the supplied `model_fn`. """ + return _replicate_model_fn_with_mode( + model_fn, + optimizer_fn, + loss_reduction, + devices, + # TODO(isaprykin): Query the system configuration to choose modes other + # than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often + # appropriate. + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER) + + +class _VariableDistributionMode(object): + """Modes for variable distribution used for forcing a particular one. + + Forcing a mode is meant for performance experimentation purposes rather than + for general use cases. + """ + + SHARED_LOCAL_PARAMETER_SERVER = 1 + """Variables are placed on a single device and shared across all devices. + + Two ways to achieve this distribution over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + """ + + SHARED_ROUND_ROBIN = 2 + """Variables are placed on all devices in a round-robin fashion. + + Every subsequent variable is placed on the next device. There is only one + copy of each variable that is shared across all devices. + """ + + +def _replicate_model_fn_with_mode( + model_fn, + optimizer_fn, + loss_reduction=losses.Reduction.SUM, + devices=None, + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER): + """A version of `replicate_model_fn` that allows to specify a `mode`.""" + if loss_reduction == losses.Reduction.NONE: + raise ValueError('Tower losses need to be reduced in some way, yet {} ' + 'reduction is specified.'.format(loss_reduction)) if not devices: devices = _get_local_devices('GPU') or _get_local_devices('CPU') is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] - local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU') + consolidation_device = '/{}:0'.format('GPU' + if is_a_single_gpu_case else 'CPU') + + ps_devices = [consolidation_device] + if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN: + ps_devices = devices - tf_logging.info('Replicating the `model_fn` across {}. Local parameter ' - 'server device is going to be {}.'.format( - devices, local_ps_device)) + tf_logging.info('Replicating the `model_fn` across {}. Variables are going ' + 'to be placed on {}. Consolidation device is going to be {}.' + .format(devices, ps_devices, consolidation_device)) def replicated_model_fn(features, labels, mode, params=None, config=None): """Replicated version of `model_fn` to be used instead.""" feature_shards, label_shards = _split_batch( - features, labels, len(devices), device=local_ps_device) + features, labels, len(devices), device=consolidation_device) tower_specs = _get_loss_towers( model_fn=model_fn, mode=mode, features=feature_shards, labels=label_shards, params=params, + loss_reduction=loss_reduction, config=config, devices=devices, - local_ps_device=local_ps_device) + local_ps_devices=ps_devices) if mode == model_fn_lib.ModeKeys.TRAIN: train_op = _minimize_towers(tower_specs, _call_optimizer_fn(optimizer_fn, params)) return _train_spec( - tower_specs, train_op, aggregation_device=local_ps_device) + tower_specs, train_op, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.EVAL: - return _eval_spec(tower_specs, aggregation_device=local_ps_device) + return _eval_spec(tower_specs, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.PREDICT: - return _predict_spec(tower_specs, aggregation_device=local_ps_device) + return _predict_spec(tower_specs, aggregation_device=consolidation_device) return replicated_model_fn @@ -222,7 +280,8 @@ def _get_loss_towers(model_fn, params, config, devices, - local_ps_device, + local_ps_devices, + loss_reduction=losses.Reduction.SUM, name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): """Replicate the loss computation across devices.""" tower_specs = [] @@ -234,15 +293,22 @@ def _get_loss_towers(model_fn, if 'config' in model_fn_args: optional_params['config'] = copy.deepcopy(config) + # pylint: disable=protected-access + round_robin_strategy = device_setter_lib._RoundRobinStrategy( + num_tasks=len(local_ps_devices)) + # pylint: enable=protected-access + for i, device in enumerate(devices): is_the_first_tower = (i == 0) device_setter = _local_device_setter( - worker_device=device, ps_device=local_ps_device) + worker_device=device, + ps_devices=local_ps_devices, + ps_strategy=round_robin_strategy) - # We would like to preserve the names of the variables and ops that a user - # might be relying on. Names with prefix are going to resolve to variables - # and ops of the first tower. + # We would like to preserve the names of the variables and ops that the user + # might be relying on. Names without a prefix are going to resolve to + # variables and ops of the first tower. name_scope = name_scope_pattern if is_the_first_tower: name_scope = '' @@ -254,16 +320,19 @@ def _get_loss_towers(model_fn, if labels: labels_shard = labels[i] - tower_specs.append( - model_fn( - mode=mode, - features=features[i], - labels=labels_shard, - **optional_params)) + tower_spec = model_fn( + mode=mode, + features=features[i], + labels=labels_shard, + **optional_params) + if loss_reduction != losses.Reduction.SUM: + tower_spec = _scale_tower_loss( + tower_spec, number_of_towers=len(devices)) + tower_specs.append(tower_spec) return tower_specs -def _local_device_setter(ps_device, worker_device): +def _local_device_setter(worker_device, ps_devices, ps_strategy): """A device setter that puts distributes Var/Ops to PS/workers.""" ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] @@ -273,7 +342,7 @@ def _local_device_setter(ps_device, worker_device): node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def if node_def.op in ps_ops: ps_device_spec = framework_device.DeviceSpec.from_string( - '{}'.format(ps_device)) + '{}'.format(ps_devices[ps_strategy(op)])) ps_device_spec.merge_from(current_device) return ps_device_spec.to_string() @@ -286,6 +355,17 @@ def _local_device_setter(ps_device, worker_device): return local_device_chooser +def _scale_tower_loss(tower_spec, number_of_towers): + """Scale down the loss for arriving at the average loss by summing.""" + if tower_spec.loss is None: + return tower_spec + + estimator_spec = _asdict(tower_spec) + estimator_spec['loss'] = math_ops.div( + tower_spec.loss, 1.0 * number_of_towers, name='averaged_loss') + return model_fn_lib.EstimatorSpec(**estimator_spec) + + def _minimize_towers(tower_specs, optimizer): """Aggregate and apply gradients for computed losses.""" grad_lists = {} @@ -335,7 +415,7 @@ def _train_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`.""" - estimator_spec = tower_specs[0]._asdict() + estimator_spec = _asdict(tower_specs[0]) estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN estimator_spec['train_op'] = train_op estimator_spec['loss'] = _compute_sum_on_device( @@ -346,7 +426,7 @@ def _train_spec(tower_specs, def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): """Populate replicated EstimatorSpec for `GraphKeys.EVAL`.""" - estimator_spec = tower_specs[0]._asdict() + estimator_spec = _asdict(tower_specs[0]) estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL estimator_spec['loss'] = _compute_sum_on_device( [spec.loss for spec in tower_specs], aggregation_device, @@ -414,7 +494,7 @@ def _reduce_metric_variables(number_of_towers): def _predict_spec(tower_specs, aggregation_device): """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`.""" - estimator_spec = tower_specs[0]._asdict() + estimator_spec = _asdict(tower_specs[0]) estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT with ops_lib.device(aggregation_device): @@ -474,3 +554,19 @@ def _dict_concat(*dicts): for k, v in six.iteritems(d): list_dict.setdefault(k, []).append(v) return list_dict + + +def _asdict(namedtuple): + """Returns a namedtuple as a dictionary. + + This is required because `_asdict()` in Python 3.x.x is broken in classes + that inherit from `collections.namedtuple`. See + https://bugs.python.org/issue24931 for more details. + + Args: + namedtuple: An object that inherits from `collections.namedtuple`. + + Returns: + A dictionary version of the tuple. + """ + return {k: getattr(namedtuple, k) for k in namedtuple._fields} diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index 5a1982f5eb52f685a6998ae64a30b29a8aa2ce11..b452e5c7359a973bea670f5760b229cf72d032f5 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -40,6 +40,7 @@ from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import losses from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variable_scope @@ -49,15 +50,30 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import device_setter from tensorflow.python.training import gradient_descent +# TODO(isaprykin): Parametrize all the tests on +# replicate_model_fn._VariableDistributionMode when it's supported. class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow(self): + def test_complete_flow_with_public_version(self): + return self._complete_flow_with_mode(mode=None) + + def test_complete_flow_with_mode_local_ps_server(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER) + + def test_complete_flow_with_mode_round_robin(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) + + def _complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 @@ -105,11 +121,20 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): def optimizer_fn(): return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + if not mode: # Use the public `replicate_model_fn`. + model_fn = replicate_model_fn.replicate_model_fn( + estimator.model_fn, + optimizer_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2']) + else: + model_fn = replicate_model_fn._replicate_model_fn_with_mode( + estimator.model_fn, + optimizer_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2'], + mode=mode) + estimator = estimator_lib.Estimator( - model_fn=replicate_model_fn.replicate_model_fn( - estimator.model_fn, - optimizer_fn, - devices=['/gpu:0', '/gpu:1', '/gpu:2']), + model_fn=model_fn, model_dir=estimator.model_dir, config=estimator.config, params=estimator.params) @@ -197,13 +222,40 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) self.assertEqual(total_loss, session.run(estimator_spec.loss)) - # loss' of c is 3. + # derivative of loss = (1*c - 1) + (2*c - 2) is 3. # new value of c = 10 - learning rate * 3 = 7.0. session.run(estimator_spec.train_op) with variable_scope.variable_scope('', reuse=True): c = variable_scope.get_variable('c', dtype=dtypes.float64) self.assertEqual(7.0, session.run(c)) + def test_train_with_mean_reduction(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + self.optimizer_fn, + losses.Reduction.MEAN, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)) / 2.0 + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # derivative of loss = (1*c - 1)/2 + (2*c - 2)/2 is 1.5. + # It's the same computation as without mean reduction, but the + # loss from every tower is scaled by 1/. + # new value of c = 10 - learning rate * 1.5 = 8.5 + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(8.5, session.run(c)) + def test_train_spec_with_optimizer_without_params(self): def optimizer_fn_without_params(): @@ -252,6 +304,38 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): self.assertEqual(0, auc) self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + def test_eval_with_mean_reduction(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + self.optimizer_fn, + losses.Reduction.MEAN, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # loss[i] = features[i] * 10 - labels[i]. + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) / 2.0 + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + def test_predict(self): features = np.array([[0.01], [0.002]]) labels = np.array([[0.01], [0.02]]) @@ -273,7 +357,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn) + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) @@ -332,6 +416,11 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): 'probabilities': np.array([[0.1], [0.02]]) }, session.run(estimator_spec.predictions)) + def test_unsupported_loss_reduction(self): + with self.assertRaisesRegexp(ValueError, ''): + _ = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, losses.Reduction.NONE) + class GetLossTowersTest(test_util.TensorFlowTestCase): @@ -359,7 +448,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): params=None, config=None, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], name_scope_pattern='test_tower_{}') session.run(variables.global_variables_initializer()) @@ -382,6 +471,88 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): c = variable_scope.get_variable('c', dtype=dtypes.float64) self.assertEqual(0.25, session.run(c)) + def test_gradients_are_computed_with_mean_reduction(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=model_fn_lib.ModeKeys.EVAL, + features=[[0.6], [1.6]], + labels=[[0.6], [0.6]], + params=None, + loss_reduction=losses.Reduction.MEAN, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_devices=['/gpu:0'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 2) + + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('averaged_loss:0', tower_specs[0].loss.name) + self.assertEqual(0.5, session.run(tower_specs[0].loss)) + + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('test_tower_1/averaged_loss:0', tower_specs[1].loss.name) + # The input batch for the second tower had a loss that is 1.0 + # bigger: 0.6 vs 1.6. + self.assertEqual(1.0, session.run(tower_specs[1].loss)) + + self.assertEqual(1, len(variables.global_variables())) + self.assertEqual(1, len(variables.trainable_variables())) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(0.25, session.run(c)) + + def test_variables_are_round_robined_correctly(self): + """Test that creates multiple variables and tests round-robin placement.""" + + def model_fn(mode, features, labels, params): + del params + for variable_name in ['a', 'b', 'c', 'd']: + c = variable_scope.get_variable( + variable_name, + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=math_ops.reduce_sum(loss)) + + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + model_fn, + mode=None, + features=[[0.6], [1.6], [2.6]], + labels=[[0.6], [0.6], [2.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1', '/gpu:3'], + local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 3) + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('/device:GPU:3', tower_specs[2].loss.device) + + with variable_scope.variable_scope('', reuse=True): + a = variable_scope.get_variable('a', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', a.device) + b = variable_scope.get_variable('b', dtype=dtypes.float64) + self.assertEqual('/device:GPU:1', b.device) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:3', c.device) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', d.device) + class SplitBatchTest(test_util.TensorFlowTestCase): @@ -604,7 +775,7 @@ class PredictSpecTest(test_util.TensorFlowTestCase): params=None, config=None, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], ) session.run(variables.global_variables_initializer()) @@ -843,33 +1014,73 @@ class GetLocalDevicesTest(test_util.TensorFlowTestCase): replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist. def test_whether_there_is_a_gpu(self): - self.assertEqual( - len(replicate_model_fn._get_local_devices('GPU')), - test.is_gpu_available()) + if test.is_gpu_available(): + self.assertTrue(len(replicate_model_fn._get_local_devices('GPU'))) class LocalDeviceSetterTest(test_util.TensorFlowTestCase): def test_vars_are_on_ps_but_ops_are_on_workers(self): + ps_devices = ['/device:GPU:3'] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + local_device_setter = replicate_model_fn._local_device_setter( - ps_device='/device:GPU:3', worker_device='/device:GPU:2') + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') with ops_lib.device(local_device_setter): - c = variables.Variable(0.01) + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', b.device) + + c = variables.Variable(0.03) self.assertEqual('/device:GPU:3', c.device) - cc = variables.Variable(0.02) - self.assertEqual('/device:GPU:3', cc.device) + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) - ccc = variables.Variable(0.03) - self.assertEqual('/device:GPU:3', ccc.device) + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + def test_round_robin_placement(self): + ps_devices = [ + '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4' + ] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:0', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:1', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:4', c.device) + + d = variables.Variable(0.03) + self.assertEqual('/device:GPU:0', d.device) c_op = array_ops.concat(c, axis=0) self.assertEqual('/device:GPU:2', c_op.device) - cc_op = array_ops.concat(cc, axis=0) - self.assertEqual('/device:GPU:2', cc_op.device) - class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index 0d67e09f8151b48c97094b6b48f26e63443707ef..f72280c4ecf19e33278ffe74061f44bbb7b21709 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -24,7 +24,7 @@ import numpy as np from tensorflow.contrib import framework from tensorflow.contrib.factorization.python.ops import gmm_ops from tensorflow.contrib.framework.python.framework import checkpoint_utils -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.python.framework import constant_op @@ -167,7 +167,7 @@ class GMM(estimator.Estimator): self._num_clusters, self._random_seed, self._covariance_type, self._params) - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses) training_op = with_dependencies([training_op, incr_step], loss) training_hooks = [_InitializeClustersHook( diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index b2f22eb2fce89415b6cc60ecbbc5c86da97ba40b..4fe22ea26ec5f5a43f1c99d1fee518b1d326c5c9 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -77,6 +77,7 @@ class _SweepHook(session_run_hook.SessionRunHook): logging.info("SweepHook running init op.") sess.run(self._init_op) if is_sweep_done: + logging.info("SweepHook starting the next sweep.") sess.run(self._switch_op) is_row_sweep = sess.run(self._is_row_sweep_var) if is_sweep_done or not self._is_initialized: @@ -91,6 +92,22 @@ class _SweepHook(session_run_hook.SessionRunHook): fetches=[self._row_train_op if is_row_sweep else self._col_train_op]) +class _IncrementGlobalStepHook(session_run_hook.SessionRunHook): + """Hook that increments the global step.""" + + def __init__(self): + global_step = training_util.get_global_step() + if global_step: + self._global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + self._global_step_incr_op = None + + def before_run(self, run_context): + if self._global_step_incr_op: + run_context.session.run(self._global_step_incr_op) + + class _StopAtSweepHook(session_run_hook.SessionRunHook): """Hook that requests stop at a given sweep.""" @@ -166,7 +183,7 @@ def _wals_factorization_model_function(features, labels, mode, params): # TRAIN mode: if mode == model_fn.ModeKeys.TRAIN: - # Training consists of the folowing ops (controlled using a SweepHook). + # Training consists of the following ops (controlled using a SweepHook). # Before a row sweep: # row_update_prep_gramian_op # initialize_row_update_op @@ -210,14 +227,6 @@ def _wals_factorization_model_function(features, labels, mode, params): summary.scalar("root_weighted_squared_error", rwse_var) summary.scalar("completed_sweeps", completed_sweeps_var) - # Increments global step. - global_step = training_util.get_global_step() - if global_step: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op() - def create_axis_ops(sp_input, num_items, update_fn, axis_name): """Creates book-keeping and training ops for a given axis. @@ -246,9 +255,6 @@ def _wals_factorization_model_function(features, labels, mode, params): collections=[ops.GraphKeys.GLOBAL_VARIABLES], trainable=False, name="processed_" + axis_name) - reset_processed_items_op = state_ops.assign( - processed_items, processed_items_init, - name="reset_processed_" + axis_name) _, update_op, loss, reg, sum_weights = update_fn(sp_input) input_indices = sp_input.indices[:, 0] with ops.control_dependencies([ @@ -264,13 +270,12 @@ def _wals_factorization_model_function(features, labels, mode, params): with ops.control_dependencies([update_processed_items]): is_sweep_done = math_ops.reduce_all(processed_items) axis_train_op = control_flow_ops.group( - global_step_incr_op, state_ops.assign(is_sweep_done_var, is_sweep_done), state_ops.assign_add( completed_sweeps_var, math_ops.cast(is_sweep_done, dtypes.int32)), name="{}_sweep_train_op".format(axis_name)) - return reset_processed_items_op, axis_train_op + return processed_items.initializer, axis_train_op reset_processed_rows_op, row_train_op = create_axis_ops( input_rows, @@ -296,7 +301,8 @@ def _wals_factorization_model_function(features, labels, mode, params): sweep_hook = _SweepHook( is_row_sweep_var, is_sweep_done_var, init_op, row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op) - training_hooks = [sweep_hook] + global_step_hook = _IncrementGlobalStepHook() + training_hooks = [sweep_hook, global_step_hook] if max_sweeps is not None: training_hooks.append(_StopAtSweepHook(max_sweeps)) diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index dc5a04a0b15870babbc98cf104e109caf829901c..eccce99071dc1477cf4f3bb152f3304b3b0fc35a 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -155,7 +155,10 @@ tf_py_test( data = [ ":test_data", ], - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) py_library( diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index 871dff7bbe4912f0daf2bc184d6b0f12510abee7..daba965a98893b992abdc598ec713f13020d6e91 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -26,6 +26,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py index 4d1fac4ef8afbf44cd45bae065f8a95b0527079a..b43b6b8919223bd7731209d5423b142601396ea5 100644 --- a/tensorflow/contrib/ffmpeg/decode_video_op_test.py +++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py @@ -20,11 +20,9 @@ from __future__ import print_function import os.path -import six +import six # pylint: disable=unused-import from tensorflow.contrib import ffmpeg -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -32,7 +30,8 @@ from tensorflow.python.platform import test class DecodeVideoOpTest(test.TestCase): - def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, index): + def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, + index): """Loads an video file and validates the output tensor. Args: @@ -40,6 +39,8 @@ class DecodeVideoOpTest(test.TestCase): width: The width of the video. height: The height of the video. frames: The frames of the video. + bmp_filename: The filename for the bmp file. + index: Index location inside the video. """ with self.test_session(): path = os.path.join(resource_loader.get_data_files_path(), 'testdata', @@ -48,7 +49,7 @@ class DecodeVideoOpTest(test.TestCase): contents = f.read() bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata', - bmp_filename) + bmp_filename) with open(bmp_path, 'rb') as f: bmp_contents = f.read() @@ -58,7 +59,7 @@ class DecodeVideoOpTest(test.TestCase): video_op = ffmpeg.decode_video(contents) video = video_op.eval() self.assertEqual(video.shape, (frames, height, width, 3)) - self.assertAllEqual(video[index,:,:,:], image) + self.assertAllEqual(video[index, :, :, :], image) def testMp4(self): self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 201774e1d011f35df9c3803f2ed8818cc9b1c1c2..1e8af1458cea13b2ddb89b7d93a4ffb8b974ecd2 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -49,7 +49,8 @@ std::vector FfmpegAudioCommandLine(const string& input_filename, "-nostdin", // No interactive commands accepted. "-f", input_format_id, // eg: "mp3" "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, - "-loglevel", "info", // Enable verbose logging to support debugging. + "-loglevel", "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. "-map_metadata", "-1", // Copy global metadata from input to output. "-vn", // No video recording. "-ac:a:0", StrCat(channel_count), "-ar:a:0", @@ -72,7 +73,8 @@ std::vector FfmpegVideoCommandLine(const string& input_filename, "-probesize", StrCat(kDefaultProbeSize), "-loglevel", - "info", // Enable verbose logging to support debugging. + "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. "-vcodec", "rawvideo", "-pix_fmt", @@ -220,7 +222,8 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, Status ReadInfoFile(const string& filename, uint32* width, uint32* height, uint32* frames) { string data; - ReadFileToString(Env::Default(), filename, &data); + TF_QCHECK_OK(ReadFileToString(Env::Default(), filename, &data)) + << "Could not read FFmpeg file: " << filename; bool in_output = false; bool in_mapping = false; uint32 frames_value = 0; @@ -377,7 +380,7 @@ Status ReadVideoFile(const string& filename, std::vector* output_data, open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600); if (fd < 0) { const int error = errno; - LOG(ERROR) << "FFmpeg stderr file coule not be created: " + LOG(ERROR) << "FFmpeg stderr file could not be created: " << strerror(error); ::_exit(error); } diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index 2871c1462894c6a4ddef63e9178272df0d14824c..85b61b26163d87a10d4e316720b4f633e038bbec 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -39,7 +39,7 @@ const char kTestMp3Filename[] = // Set to true via a command line flag iff the test is expected to have FFmpeg // installed. -mutex mu; +mutex mu(LINKER_INITIALIZED); bool should_ffmpeg_be_installed GUARDED_BY(mu) = false; string ParseTestFlags(int* argc, char** argv) { diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 39e7e90cccf1012eb42261bde55d0dc3b7f278ef..36fc71794b06e0f3cb86c40b325ce50e8999c667 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -23,6 +23,7 @@ #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 78ead471d2cf9f0654a06dc022d7cc592d14c710..08b5a6ea48c2d4959af68a2ee9d27d21c6245457 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index e8dad886a1409babdf4ea47b9cd05def1f1ce25e..5b659ddaa1386736eb8cc05a203ed1827ccd160e 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -276,6 +276,7 @@ py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:partitioned_variables", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variable_scope", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 3f592611830e40a30392239c85486a2fad15a2a2..4edc77f86ba786ca547b8d3842e2cf02833fbbac 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -65,6 +65,7 @@ See the @{$python/contrib.framework} guide. @@get_variable_full_name @@get_variables_to_restore @@get_variables +@@global_variable @@local_variable @@model_variable @@variable diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 8ab8711db4650921e0d366a91adfe2f68b5a42f9..a18ff2320d99726bb355ff6179fc97a070c2fec7 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -24,12 +24,14 @@ import six # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import ops from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes from tensorflow.python.framework.graph_util_impl import _extract_graph_summary from tensorflow.python.framework.graph_util_impl import _node_name -__all__ = ["fuse_op"] + +__all__ = ["fuse_op", "get_placeholders"] def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, @@ -91,7 +93,7 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, (n, cur_node)) if cur_node not in input_nodes_set: next_to_visit += name_to_input_name[cur_node] - else: + elif n not in reachable_by_input: nodes_post_output.append(n) # Add all nodes upto the input nodes @@ -126,3 +128,27 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, out.library.CopyFrom(graph_def.library) out.versions.CopyFrom(graph_def.versions) return out + + +def get_placeholders(graph): + """Get placeholders of a graph. + + Args: + graph: A tf.Graph. + Returns: + A list contains all placeholders of given graph. + + Raises: + TypeError: If `graph` is not a tensorflow graph. + """ + + if not isinstance(graph, ops.Graph): + raise TypeError("Input graph needs to be a Graph: %s" % graph) + + # For each placeholder() call, there is a corresponding + # operation of type 'Placeholder' registered to the graph. + # The return value (a Tensor) of placeholder() is the + # first output of this operation in fact. + operations = graph.get_operations() + result = [i.outputs[0] for i in operations if i.type == "Placeholder"] + return result diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index 87b992e22e1ad3aa20389d0834eeb3a5972c676e..b8a6d109e19211d271c2b15bac66ddacd38fe395 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -21,6 +21,9 @@ from tensorflow.contrib.framework.python.framework import graph_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -56,6 +59,41 @@ class GraphUtilTest(test.TestCase): self.assertEqual(fused_graph_def.node[2].name, 'D') self.assertEqual(fused_graph_def.node[3].name, 'E') + def testGraphUtilArtificialDependencyInjection(self): + graph_def = graph_pb2.GraphDef() + node_a = GetNewNode('A', 'Placeholder', []) + node_a1 = GetNewNode('A1', 'Placeholder', []) + node_b = GetNewNode('B', 'Op1', ['A']) + node_c = GetNewNode('C', 'Op1', ['B']) + node_d = GetNewNode('D', 'Op1', ['C']) + node_e = GetNewNode('E', 'Op1', ['D']) + graph_def.node.extend([node_a, node_a1, node_b, node_c, node_d, node_e]) + fused_graph_def = graph_util.fuse_op(graph_def, ['A', 'A1'], ['D'], + [types_pb2.DT_FLOAT], True, 'FusedOp', + 'Op2') + self.assertEqual(len(fused_graph_def.node), 5) + self.assertEqual(fused_graph_def.node[0].name, 'A') + self.assertEqual(fused_graph_def.node[1].name, 'A1') + self.assertEqual(fused_graph_def.node[2].name, 'FusedOp') + self.assertEqual(fused_graph_def.node[2].input[0], 'A') + self.assertEqual(fused_graph_def.node[2].op, 'Op2') + self.assertEqual(fused_graph_def.node[2].attr['_output_quantized'].b, True) + self.assertEqual(fused_graph_def.node[2].attr['_output_types'].list.type, + [types_pb2.DT_FLOAT]) + self.assertEqual(fused_graph_def.node[3].name, 'D') + self.assertEqual(fused_graph_def.node[4].name, 'E') + + +class GetPlaceholdersTest(test.TestCase): + + def test_get_placeholders(self): + with ops.Graph().as_default() as g: + placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)] + results = graph_util.get_placeholders(g) + self.assertEqual( + sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access + sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py index a0667bd489213cf366e27114a91e8699ed9e7428..2375ee4f550616ff60d20b87b5773704d8fbbe1e 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py @@ -48,7 +48,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] # Explicitly pass shape and type - tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) # [[7, 4], # [6, 14]] ``` @@ -93,7 +93,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): elif len(inputs) == 1 and name is not None: return array_ops.identity(inputs[0], name=name) elif context.in_eager_mode(): - # TemporaryVariable not currently supported in eager mode; fall back + # TemporaryVariable not currently supported in eager mode; fall back # onto AddN for now. # TODO(frreiss) remove this once the lifetime of eager variables gets # addressed @@ -101,7 +101,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): else: return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) -# The following code should eventually be merged into +# The following code should eventually be merged into # tensorflow/python/ops/math_grad.py @ops.RegisterGradient("AccumulateNV2") def _AddNGrad(op, grad): diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py index c2229bb8ad3d5b38321d16f150ed94175ab9bdbe..8f44698da851b48abf831e957c80fa1643a58bda 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into +"""Tests for new version of accumulate_n op that will eventually go into `ops.math_ops`. -These test cases spefically exercise the `eager` APIs. They need to be in a +These test cases spefically exercise the `eager` APIs. They need to be in a separate file from the remaining tests because eager mode is currently something you can turn on but can't turn off for the lifetime of the current process.""" from __future__ import absolute_import @@ -64,7 +64,7 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): np.random.seed(42) num_inputs = 3 input_vars = [ - resource_variable_ops.ResourceVariable(10.0 * np.random.random(), + resource_variable_ops.ResourceVariable(10.0 * np.random.random(), name="t%d" % i) for i in range(0, num_inputs) ] @@ -72,7 +72,7 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): def fn(first, second, third): return av2.accumulate_n_v2([first, second, third]) - grad_fn = backprop.gradients_function(fn) + grad_fn = backprop.gradients_function(fn) grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 [elem.numpy() for elem in grad]) diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py index 3386e849d5cb8516ab3b1f6cb0429be3fc2fc960..b5e9f8df79262635bf579a6bf2260bc40c140c6f 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into +"""Tests for new version of accumulate_n op that will eventually go into `ops.math_ops`.""" from __future__ import absolute_import from __future__ import division @@ -102,21 +102,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(np.array([0.1,0.2])) b = variables.Variable(np.array([[0.3],[0.4]])) - tf_val = av2.accumulate_n_v2([a,b]) + tf_val = av2.accumulate_n_v2([a,b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) if __name__ == "__main__": diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index b7668379686b4f0ba2a3e415ddb44b287659baaa..3f1ece4510578b5ac39849c577fffbb2a3be45a7 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -60,6 +60,7 @@ __all__ = ['add_model_variable', 'get_variable_full_name', 'get_variables_to_restore', 'get_variables', + 'global_variable', 'local_variable', 'model_variable', 'variable', @@ -147,20 +148,48 @@ def get_or_create_global_step(graph=None): return training_util.get_or_create_global_step(graph) -def local_variable(initial_value, validate_shape=True, name=None): - """Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection. +def local_variable(initial_value, + validate_shape=True, + name=None, + use_resource=None): + """Create a variable with a value and add it to `GraphKeys.LOCAL_VARIABLES`. Args: initial_value: See variables.Variable.__init__. validate_shape: See variables.Variable.__init__. name: See variables.Variable.__init__. + use_resource: If `True` use a ResourceVariable instead of a Variable. Returns: New variable. """ return variable_scope.variable( initial_value, trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], - validate_shape=validate_shape, name=name) + validate_shape=validate_shape, + use_resource=use_resource, + name=name) + + +def global_variable(initial_value, + validate_shape=True, + name=None, + use_resource=None): + """Create a variable with a value and add it to `GraphKeys.GLOBAL_VARIABLES`. + + Args: + initial_value: See variables.Variable.__init__. + validate_shape: See variables.Variable.__init__. + name: See variables.Variable.__init__. + use_resource: If `True` use a ResourceVariable instead of a Variable. + Returns: + New variable. + """ + return variable_scope.variable( + initial_value, trainable=False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + validate_shape=validate_shape, + use_resource=use_resource, + name=name) @contrib_add_arg_scope @@ -412,7 +441,7 @@ def get_unique_variable(var_op_name): """ candidates = get_variables(scope=var_op_name) if not candidates: - raise ValueError('Couldnt find variable %s' % var_op_name) + raise ValueError('Couldn\'t find variable %s' % var_op_name) for candidate in candidates: if candidate.op.name == var_op_name: diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 6a74e4e8666e98ca3c97dc9ddd8a6c11613f708e..2f06df93acb0a4c0b36c68839ff531e3c22c5ee3 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import gfile @@ -102,6 +103,82 @@ class LocalVariableTest(test.TestCase): sess.run(variables_lib.local_variables_initializer()) self.assertAllEqual(a.eval(), [0] * 5) + def testResourceVariable(self): + a = variables_lib2.local_variable(0) + b = variables_lib2.local_variable(0, use_resource=True) + self.assertEqual(type(a), variables_lib.Variable) + self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + + +class GlobalVariableTest(test.TestCase): + + def test_global_variable(self): + with self.test_session() as sess: + self.assertEquals([], variables_lib.global_variables()) + value0 = 42 + variables_lib2.global_variable(value0) + value1 = 43 + variables_lib2.global_variable(value1) + variables = variables_lib.global_variables() + self.assertEquals(2, len(variables)) + with self.assertRaisesOpError( + 'Attempting to use uninitialized value Variable'): + sess.run(variables) + variables_lib.variables_initializer(variables).run() + self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) + + def testVariableNameAndShape(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable([1, 1, 1, 1, 1], name='a') + self.assertEquals(a.op.name, 'A/a') + self.assertListEqual(a.get_shape().as_list(), [5]) + self.assertListEqual([a], variables_lib.global_variables()) + + def testGlobalVariableNotInLocalVariables(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + self.assertFalse(a in variables_lib.local_variables()) + self.assertTrue(a in variables_lib.global_variables()) + + def testGlobalVariableInVariablesToRestore(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + self.assertFalse(a in variables_lib.local_variables()) + self.assertTrue(a in variables_lib2.get_variables_to_restore()) + + def testGetVariablesReturnsThem(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + with variable_scope.variable_scope('B'): + b = variables_lib2.global_variable(0) + self.assertEquals([a], variables_lib2.get_variables('A')) + self.assertEquals([b], variables_lib2.get_variables('B')) + + def testGetLocalVariablesDontReturnsThem(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + variables_lib2.global_variable(0) + with variable_scope.variable_scope('B'): + variables_lib2.global_variable(0) + self.assertEquals([], variables_lib2.get_local_variables('A')) + self.assertEquals([], variables_lib2.get_local_variables('B')) + + def testInitializedVariableValue(self): + with self.test_session() as sess: + a = variables_lib2.global_variable([0, 0, 0, 0, 0], name='a') + sess.run(variables_lib.global_variables_initializer()) + self.assertAllEqual(a.eval(), [0] * 5) + + def testResourceVariable(self): + a = variables_lib2.global_variable(0) + b = variables_lib2.global_variable(0, use_resource=True) + self.assertEqual(type(a), variables_lib.Variable) + self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + class GlobalStepTest(test.TestCase): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 88306094ab9947c9c78b03c0013f6afc88316803..5fec69ea4361a97c79ddc3188469e7ffb327f6cc 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -493,6 +493,8 @@ void LaunchFusedConv2DBiasActivationOp:: {{conv_input_rows, conv_input_cols}}, output_depth, {{filter_rows, filter_cols}}, + // TODO(yangzihao): Add support for arbitrary dilations for fused conv. + {{1, 1}}, // dilation_rows, dilation_cols {{row_stride, col_stride}}, {{padding_rows, padding_cols}}, conv_input->dtype(), diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index dc43af11580ce5fda74ee25da6c151a5b89c7aee..fa7a3c03aa35c756252b22a004be91fa24c10e41 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -30,11 +30,12 @@ class FusedConvParameters : public ConvParameters { public: FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, - const SpatialArray& stride, const SpatialArray& padding, - DataType dtype, int device_id, bool has_side_input, + const SpatialArray& dilation, const SpatialArray& stride, + const SpatialArray& padding, DataType dtype, + int device_id, bool has_side_input, ActivationMode activation_mode) - : ConvParameters(batch, in_depths, in, out_depths, filter, stride, - padding, dtype, device_id), + : ConvParameters(batch, in_depths, in, out_depths, filter, dilation, + stride, padding, dtype, device_id), activation_mode_(activation_mode), has_side_input_(has_side_input) { hash_code_ = Hash64Combine(hash_code_, has_side_input); diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 887ebc5a6c35379476fa1a643c866d38e2b25699..6a56237f67c844a3daa546eb02d64c9e2658f639 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -52,6 +52,7 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") .Attr("activation_mode: {'Relu'} = 'Relu'") + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; using shape_inference::DimensionHandle; @@ -151,6 +152,11 @@ REGISTER_OP("FusedConv2DBiasActivation") kernel_height, kernel_width, input_channels % 4 ]` activation_mode: The activation applied to the output. Currently must be "Relu". + dilations: 1-D tensor of length 4. The dilation factor for each dimension + of `input`. If set to k > 1, there will be k-1 skipped cells between + each filter element on that dimension. The dimension order is determined + by the value of `data_format`, see above for details. Dilations in the + batch and depth dimensions must be 1. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 1418c87023af0dbff890f46e10f0140d5b89e4b7..b355a79b1a5d967eb82a30d41c073bbb52e0364c 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -56,6 +56,7 @@ py_test( srcs = ["python/train_test.py"], srcs_version = "PY2AND3", deps = [ + ":features", ":namedtuples", ":train", "//tensorflow/contrib/framework:framework_py", @@ -82,6 +83,7 @@ py_library( deps = [ ":classifier_metrics", ":eval_utils", + ":sliced_wasserstein", ":summaries", "//tensorflow/python:util", ], @@ -116,6 +118,7 @@ py_library( deps = [ ":clip_weights", ":conditioning_utils", + ":random_tensor_pool", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -219,6 +222,37 @@ py_test( ], ) +py_library( + name = "random_tensor_pool", + srcs = [ + "python/features/python/random_tensor_pool.py", + "python/features/python/random_tensor_pool_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:util", + ], +) + +py_test( + name = "random_tensor_pool_test", + srcs = ["python/features/python/random_tensor_pool_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":random_tensor_pool", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//third_party/py/numpy", + ], +) + py_library( name = "virtual_batchnorm", srcs = [ @@ -470,6 +504,41 @@ py_test( ], ) +py_library( + name = "sliced_wasserstein", + srcs = [ + "python/eval/python/sliced_wasserstein.py", + "python/eval/python/sliced_wasserstein_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sliced_wasserstein_test", + srcs = ["python/eval/python/sliced_wasserstein_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":sliced_wasserstein", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 3ab84780705b35567169bd76fd3485ad355ba9d8..4ead66ca13e74bacc0e4679a8d5c4e0f23d04b69 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -8,7 +8,8 @@ explicitly model the distribution and without writing an explicit loss. For example, the generator could learn to draw samples from the distribution of natural images. For more details on this technique, see ['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. +Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an +introduction. #### Usage ```python @@ -23,8 +24,8 @@ mix TFGAN, native TF, and other custom frameworks * Use already implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (ex Wasserstein loss, gradient penalty, mutual information penalty, etc) * [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) them * Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training -* Develop based on examples of common GAN setups -* Use the TFGAN-backed tf.Learn Estimator to easily train a GAN model +* Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/) +* Use the TFGAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model * Improvements in TFGAN infrastructure will automatically benefit your TFGAN project * Stay up-to-date with research as we add more algorithms @@ -51,7 +52,7 @@ network to evaluate your unconditional generative model. You can also use your own pretrained classifier for more specific performance numbers, or use other methods for evaluating conditional generative models. -* examples (coming soon): +* [examples](https://github.com/tensorflow/models/tree/master/research/gan/) and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your own project. These include unconditional and conditional GANs, InfoGANs, adversarial losses on existing networks, and image-to-image translation. @@ -98,8 +99,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss) + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss) # Create the train ops, which calculate gradients and apply updates to weights. train_ops = tfgan.gan_train_ops( @@ -160,8 +161,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss and standard pixel loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, gradient_penalty=1.0) l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) @@ -192,8 +193,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss and standard pixel loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.least_squares_generator_loss, - discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss) + generator_loss_fn=tfgan.losses.least_squares_generator_loss, + discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss) l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) # Modify the loss tuple to include the pixel loss. @@ -222,8 +223,8 @@ gan_model = tfgan.infogan_model( # Build the GAN loss with mutual information penalty. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, gradient_penalty=1.0, mutual_information_penalty_weight=1.0) diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index dff361fdc42708ea69999c2def4721f9d49fcf14..f1946c7f925660eae3aaa650c437e03da1f33d6c 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN is a lightweight library for training and evaluating GANs. + +In addition to providing the infrastructure for easily training and evaluating +GANS, this library contains modules for a TFGAN-backed Estimator, +evaluation metrics, features (such as virtual batch normalization), and losses. +Please see README.md for details and usage. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 8c4a18228039cb4f2c06e0333f4b8408f1f631e9..c9f7bc61b25230e4159cf8cbc7c9cceead0aa706 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN estimator module. + +GANEstimator provides all the infrastructure support of a TensorFlow Estimator +with the feature support of TFGAN. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 0824ecf616caa91938c365d0c117287ed9ea8f32..d3dca3d9e75fe1ef3be67143e18c0b51e84ad24c 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import enum from tensorflow.contrib.framework.python.ops import variables as variable_lib @@ -29,6 +30,7 @@ from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope +from tensorflow.python.util import tf_inspect as inspect __all__ = [ @@ -105,6 +107,7 @@ class GANEstimator(estimator.Estimator): discriminator_loss_fn=None, generator_optimizer=None, discriminator_optimizer=None, + get_hooks_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -116,7 +119,10 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. + generator. See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details @@ -132,6 +138,10 @@ class GANEstimator(estimator.Estimator): work. discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. These hooks are run on the generator and discriminator + train ops, and can be used to implement the GAN training scheme. + Defaults to `train.get_sequential_train_hooks()`. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. @@ -146,7 +156,7 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries) + use_loss_summaries, get_hooks_fn=get_hooks_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) @@ -155,11 +165,6 @@ class GANEstimator(estimator.Estimator): model_fn=_model_fn, model_dir=model_dir, config=config) -def _use_check_shapes(real_data): - """Determines whether TFGAN should check Tensor shapes.""" - return isinstance(real_data, ops.Tensor) - - def _gan_model_fn( features, labels, @@ -225,16 +230,19 @@ def _gan_model_fn( labels=None) -def _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for training.""" +def _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, mode): + """Make a `GANModel`, and optionally pass in `mode`.""" + # If `generator_fn` has an argument `mode`, pass mode to it. + if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, mode=mode) gan_model = tfgan_train.gan_model( generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope=generator_scope, - check_shapes=_use_check_shapes(real_data)) + check_shapes=False) if add_summaries: if not isinstance(add_summaries, (tuple, list)): add_summaries = [add_summaries] @@ -245,15 +253,28 @@ def _make_train_gan_model(generator_fn, discriminator_fn, real_data, return gan_model +def _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for training.""" + return _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, + model_fn_lib.ModeKeys.TRAIN) + + def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries): """Make a `GANModel` for evaluation.""" - return _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries) + return _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, + model_fn_lib.ModeKeys.EVAL) def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): """Make a `GANModel` from just the generator.""" + # If `generator_fn` has an argument `mode`, pass mode to it. + if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, + mode=model_fn_lib.ModeKeys.PREDICT) with variable_scope.variable_scope(generator_scope) as gen_scope: generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access generated_data = generator_fn(generator_inputs) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 1bfdce9ee94d4d05d5186cd999361662bc0e3f85..e752f0bcccda418b79d4fdabb27807394cbbb425 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -48,7 +48,8 @@ from tensorflow.python.training import training from tensorflow.python.training import training_util -def generator_fn(noise_dict): +def generator_fn(noise_dict, mode): + del mode noise = noise_dict['x'] return layers.fully_connected(noise, noise.shape[1].value) @@ -90,7 +91,6 @@ def mock_head(testcase, expected_generator_inputs, expected_real_data, generator_var_names, set([x.name for x in gan_model.generator_variables])) testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) - testcase.assertEqual(generator_fn, gan_model.generator_fn) testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) # TODO(joelshor): Add check on `discriminator_real_outputs`. # TODO(joelshor): Add check on `discriminator_gen_outputs`. diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 204c646e194319c0e63599da0b2a4909ef270ef3..a21358c50bbdb4a1a929b0c5bc322cec4c9923b5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -71,7 +71,7 @@ class GANHead(head._Head): # pylint: disable=protected-access def __init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, - get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + get_hooks_fn=None, name=None): """`Head` for GAN training. @@ -86,10 +86,12 @@ class GANHead(head._Head): # pylint: disable=protected-access use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + of hooks. Defaults to `train.get_sequential_train_hooks()` name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + if get_hooks_fn is None: + get_hooks_fn = tfgan_train.get_sequential_train_hooks() # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index bb8046187807d0cc584f7174eb9aac578855c110..f86b8513053a45f9830411f7df2c32d1f36a97b2 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN evaluation module. + +This module supports techniques such as Inception Score, Frechet Inception +distance, and Sliced Wasserstein distance. +""" # pylint: disable=,wildcard-import,unused-import from __future__ import absolute_import @@ -22,10 +26,12 @@ from __future__ import print_function # Collapse eval into a single namespace. from tensorflow.contrib.gan.python.eval.python import classifier_metrics from tensorflow.contrib.gan.python.eval.python import eval_utils +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein from tensorflow.contrib.gan.python.eval.python import summaries from tensorflow.contrib.gan.python.eval.python.classifier_metrics import * from tensorflow.contrib.gan.python.eval.python.eval_utils import * +from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein import * from tensorflow.contrib.gan.python.eval.python.summaries import * # pylint: enable=wildcard-import,unused-import @@ -33,7 +39,10 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'classifier_metrics', + 'sliced_wasserstein_distance', 'summaries', 'eval_utils', -] + classifier_metrics.__all__ + summaries.__all__ + eval_utils.__all__ +] + ( + classifier_metrics.__all__ + sliced_wasserstein.__all__ + + summaries.__all__ + eval_utils.__all__) remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index bb65f05b5a17e9a872e41d1dcb05aeb3cd6f6f40..82293b575aefa198a618ae7286ca24ebabd6987d 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -57,8 +57,10 @@ __all__ = [ 'run_inception', 'inception_score', 'classifier_score', + 'classifier_score_from_logits', 'frechet_inception_distance', 'frechet_classifier_distance', + 'frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] @@ -222,13 +224,13 @@ def run_inception(images, image_size: Required image width and height. See unit tests for the default values. input_tensor: Name of input Tensor. - output_tensor: Name of output Tensor. This function will compute activations - at the specified layer. Examples include INCEPTION_V3_OUTPUT and - INCEPTION_V3_FINAL_POOL which would result in this function computing + output_tensor: Name or list of output Tensors. This function will compute + activations at the specified layer. Examples include INCEPTION_V3_OUTPUT + and INCEPTION_V3_FINAL_POOL which would result in this function computing the final logits or the penultimate pooling layer. Returns: - Logits. + Tensor or Tensors corresponding to computed `output_tensor`. Raises: ValueError: If images are not the correct size. @@ -244,8 +246,14 @@ def run_inception(images, activations = run_image_classifier(images, graph_def, input_tensor, output_tensor) - if array_ops.rank(activations) != 2: - activations = layers.flatten(activations) + if isinstance(activations, list): + for i, activation in enumerate(activations): + if array_ops.rank(activation) != 2: + activations[i] = layers.flatten(activation) + else: + if array_ops.rank(activations) != 2: + activations = layers.flatten(activations) + return activations @@ -257,23 +265,26 @@ def run_image_classifier(tensor, graph_def, input_tensor, tensor: An Input tensor. graph_def: A GraphDef proto. input_tensor: Name of input tensor in graph def. - output_tensor: Name of output tensor in graph def. + output_tensor: A tensor name or list of tensor names in graph def. scope: Name scope for classifier. Returns: - Classifier output. Shape depends on the classifier used, but is often - [batch, classes]. + Classifier output if `output_tensor` is a string, or a list of outputs if + `output_tensor` is a list. Raises: - ValueError: If `image_size` is not `None`, and `tensor` are not the correct - size. + ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def. """ input_map = {input_tensor: tensor} - return_elements = [output_tensor] - classifier_output = importer.import_graph_def( - graph_def, input_map, return_elements, name=scope)[0] + is_singleton = isinstance(output_tensor, str) + if is_singleton: + output_tensor = [output_tensor] + classifier_outputs = importer.import_graph_def( + graph_def, input_map, output_tensor, name=scope) + if is_singleton: + classifier_outputs = classifier_outputs[0] - return classifier_output + return classifier_outputs def classifier_score(images, classifier_fn, num_batches=1): @@ -312,6 +323,30 @@ def classifier_score(images, classifier_fn, num_batches=1): swap_memory=True, name='RunClassifier') logits = array_ops.concat(array_ops.unstack(logits), 0) + + return classifier_score_from_logits(logits) + + +def classifier_score_from_logits(logits): + """Classifier score for evaluating a conditional generative model. + + This is based on the Inception Score, but for an arbitrary classifier. + + This technique is described in detail in https://arxiv.org/abs/1606.03498. In + summary, this function calculates + + exp( E[ KL(p(y|x) || p(y)) ] ) + + which captures how different the network's classification prediction is from + the prior distribution over classes. + + Args: + logits: A 2D Tensor of logits. + + Returns: + The classifier score. A floating-point scalar of the same type as the output + of `logits`. + """ logits.shape.assert_has_rank(2) # Use maximum precision for best results. @@ -436,31 +471,71 @@ def frechet_classifier_distance(real_images, swap_memory=True, name='RunClassifier') - activations_dtype = activations.dtype # Split the activations by the real and generated images. real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0) # Ensure the activations have the right shapes. real_a = array_ops.concat(array_ops.unstack(real_a), 0) gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) - if activations_dtype != dtypes.float64: - real_a = math_ops.to_double(real_a) - gen_a = math_ops.to_double(gen_a) - real_a.shape.assert_has_rank(2) - gen_a.shape.assert_has_rank(2) + return frechet_classifier_distance_from_activations(real_a, gen_a) + + +def frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model. + + This is based on the Frechet Inception distance, but for an arbitrary + classifier. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + Args: + real_activations: Real images to use to compute Frechet Inception distance. + generated_activations: Generated images to use to compute Frechet Inception + distance. + + Returns: + The Frechet Inception distance. A floating-point scalar of the same type + as the output of the activations. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) # Compute mean and covariance matrices of activations. - m = math_ops.reduce_mean(real_a, 0) - m_v = math_ops.reduce_mean(gen_a, 0) - num_examples = math_ops.to_double(array_ops.shape(real_a)[0]) + m = math_ops.reduce_mean(real_activations, 0) + m_v = math_ops.reduce_mean(generated_activations, 0) + num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T + real_centered = real_activations - m sigma = math_ops.matmul( - real_a - m, real_a - m, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_v sigma_v = math_ops.matmul( - gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1) + gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) # Find the Tr(sqrt(sigma sigma_v)) component of FID sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 92e0a995748c1c4c2ddfff0daae59be5a6eaefb4..1e18c699ba93b5f524341c65d0a2db84556b65a2 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -190,6 +190,23 @@ class ClassifierMetricsTest(test.TestCase): # Check that none of the model variables are trainable. self.assertListEqual([], variables.trainable_variables()) + def test_run_inception_multiple_outputs(self): + """Test `run_inception` graph construction with multiple outputs.""" + batch_size = 3 + img = array_ops.ones([batch_size, 299, 299, 3]) + logits, pool = _run_with_mock( + classifier_metrics.run_inception, img, + output_tensor=[classifier_metrics.INCEPTION_OUTPUT, + classifier_metrics.INCEPTION_FINAL_POOL]) + + self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertTrue(isinstance(pool, ops.Tensor)) + logits.shape.assert_is_compatible_with([batch_size, 1001]) + pool.shape.assert_is_compatible_with([batch_size, 2048]) + + # Check that none of the model variables are trainable. + self.assertListEqual([], variables.trainable_variables()) + def test_inception_score_graph(self): """Test `inception_score` graph construction.""" score = _run_with_mock(classifier_metrics.inception_score, diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py new file mode 100644 index 0000000000000000000000000000000000000000..523968bed91f1021ae629bf52c405cf5c2d7b917 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model evaluation tools for TFGAN.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = sliced_wasserstein_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9bebcacbe46d85fc4226c4275b71b3ecbde57a97 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -0,0 +1,282 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of Sliced Wasserstein Distance. + +Proposed in https://arxiv.org/abs/1710.10196 and the official Theano +implementation that we used as reference can be found here: +https://github.com/tkarras/progressive_growing_of_gans + +Note: this is not an exact distance but an approximation through random +projections. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import script_ops + +__all__ = ['sliced_wasserstein_distance'] +_GAUSSIAN_FILTER = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ + 6, 24, 36, 24, 6 +], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]).reshape([5, 5, 1, 1]) / 256.0 + + +def _laplacian_pyramid(batch, num_levels): + """Compute a Laplacian pyramid. + + Args: + batch: (tensor) The batch of images (batch, height, width, channels). + num_levels: (int) Desired number of hierarchical levels. + Returns: + List of tensors from the highest to lowest resolution. + """ + gaussian_filter = constant_op.constant(_GAUSSIAN_FILTER) + + def spatial_conv(batch, gain): + s = array_ops.shape(batch) + padded = array_ops.pad(batch, [[0, 0], [2, 2], [2, 2], [0, 0]], 'REFLECT') + xt = array_ops.transpose(padded, [0, 3, 1, 2]) + xt = array_ops.reshape(xt, [s[0] * s[3], s[1] + 4, s[2] + 4, 1]) + conv_out = nn_ops.conv2d(xt, gaussian_filter * gain, [1] * 4, 'VALID') + conv_xt = array_ops.reshape(conv_out, [s[0], s[3], s[1], s[2]]) + conv_xt = array_ops.transpose(conv_xt, [0, 2, 3, 1]) + return conv_xt + + def pyr_down(batch): # matches cv2.pyrDown() + return spatial_conv(batch, 1)[:, ::2, ::2] + + def pyr_up(batch): # matches cv2.pyrUp() + s = array_ops.shape(batch) + zeros = array_ops.zeros([3 * s[0], s[1], s[2], s[3]]) + res = array_ops.concat([batch, zeros], 0) + res = array_ops.batch_to_space(res, crops=[[0, 0], [0, 0]], block_size=2) + res = spatial_conv(res, 4) + return res + + pyramid = [math_ops.to_float(batch)] + for _ in range(1, num_levels): + pyramid.append(pyr_down(pyramid[-1])) + pyramid[-2] -= pyr_up(pyramid[-1]) + return pyramid + + +def _batch_to_patches(batch, patches_per_image, patch_size): + """Extract patches from a batch. + + Args: + batch: (tensor) The batch of images (batch, height, width, channels). + patches_per_image: (int) Number of patches to extract per image. + patch_size: (int) Size of the patches (size, size, channels) to extract. + Returns: + Tensor (batch*patches_per_image, patch_size, patch_size, channels) of + patches. + """ + + def py_func_random_patches(batch): + """Numpy wrapper.""" + batch_size, height, width, channels = batch.shape + patch_count = patches_per_image * batch_size + hs = patch_size // 2 + # Randomly pick patches. + patch_id, y, x, chan = np.ogrid[0:patch_count, -hs:hs + 1, -hs:hs + 1, 0:3] + img_id = patch_id // patches_per_image + # pylint: disable=g-no-augmented-assignment + # Need explicit addition for broadcast to work properly. + y = y + np.random.randint(hs, height - hs, size=(patch_count, 1, 1, 1)) + x = x + np.random.randint(hs, width - hs, size=(patch_count, 1, 1, 1)) + # pylint: enable=g-no-augmented-assignment + idx = ((img_id * height + y) * width + x) * channels + chan + patches = batch.flat[idx] + return patches + + patches = script_ops.py_func( + py_func_random_patches, [batch], batch.dtype, stateful=False) + return patches + + +def _normalize_patches(patches): + """Normalize patches by their mean and standard deviation. + + Args: + patches: (tensor) The batch of patches (batch, size, size, channels). + Returns: + Tensor (batch, size, size, channels) of the normalized patches. + """ + patches = array_ops.concat(patches, 0) + mean, variance = nn.moments(patches, [1, 2, 3], keep_dims=True) + patches = (patches - mean) / math_ops.sqrt(variance) + return array_ops.reshape(patches, [array_ops.shape(patches)[0], -1]) + + +def _sort_rows(matrix, num_rows): + """Sort matrix rows by the last column. + + Args: + matrix: a matrix of values (row,col). + num_rows: (int) number of sorted rows to return from the matrix. + Returns: + Tensor (num_rows, col) of the sorted matrix top K rows. + """ + tmatrix = array_ops.transpose(matrix, [1, 0]) + sorted_tmatrix = nn_ops.top_k(tmatrix, num_rows)[0] + return array_ops.transpose(sorted_tmatrix, [1, 0]) + + +def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): + """Compute the approximate sliced Wasserstein distance. + + Args: + a: (matrix) Distribution "a" of samples (row, col). + b: (matrix) Distribution "b" of samples (row, col). + random_sampling_count: (int) Number of random projections to average. + random_projection_dim: (int) Dimension of the random projection space. + Returns: + Float containing the approximate distance between "a" and "b". + """ + s = array_ops.shape(a) + means = [] + for _ in range(random_sampling_count): + # Random projection matrix. + proj = random_ops.random_normal( + [array_ops.shape(a)[1], random_projection_dim]) + proj *= math_ops.rsqrt( + math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + # Project both distributions and sort them. + proj_a = math_ops.matmul(a, proj) + proj_b = math_ops.matmul(b, proj) + proj_a = _sort_rows(proj_a, s[0]) + proj_b = _sort_rows(proj_b, s[0]) + # Pairwise Wasserstein distance. + wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) + means.append(wdist) + return math_ops.reduce_mean(means) + + +def _sliced_wasserstein_svd(a, b): + """Compute the approximate sliced Wasserstein distance using an SVD. + + This is not part of the paper, it's a variant with possibly more accurate + measure. + + Args: + a: (matrix) Distribution "a" of samples (row, col). + b: (matrix) Distribution "b" of samples (row, col). + Returns: + Float containing the approximate distance between "a" and "b". + """ + s = array_ops.shape(a) + # Random projection matrix. + sig, u = linalg_ops.svd(array_ops.concat([a, b], 0))[:2] + proj_a, proj_b = array_ops.split(u * sig, 2, axis=0) + proj_a = _sort_rows(proj_a[:, ::-1], s[0]) + proj_b = _sort_rows(proj_b[:, ::-1], s[0]) + # Pairwise Wasserstein distance. + wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) + return wdist + + +def sliced_wasserstein_distance(real_images, + fake_images, + resolution_min=16, + patches_per_image=64, + patch_size=7, + random_sampling_count=1, + random_projection_dim=7 * 7 * 3, + use_svd=False): + """Compute the Wasserstein distance between two distributions of images. + + Note that measure vary with the number of images. Use 8192 images to get + numbers comparable to the ones in the original paper. + + Args: + real_images: (tensor) Real images (batch, height, width, channels). + fake_images: (tensor) Fake images (batch, height, width, channels). + resolution_min: (int) Minimum resolution for the Laplacion pyramid. + patches_per_image: (int) Number of patches to extract per image per + Laplacian level. + patch_size: (int) Width of a square patch. + random_sampling_count: (int) Number of random projections to average. + random_projection_dim: (int) Dimension of the random projection space. + use_svd: experimental method to compute a more accurate distance. + Returns: + List of tuples (distance_real, distance_fake) for each level of the + Laplacian pyramid from the highest resoluion to the lowest. + distance_real is the Wasserstein distance between real images + distance_fake is the Wasserstein distance between real and fake images. + Raises: + ValueError: If the inputs shapes are incorrect. Input tensor dimensions + (batch, height, width, channels) are expected to be known at graph + construction time. In addition height and width must be the same and the + number of colors should be exactly 3. Real and fake images must have the + same size. + """ + height = real_images.shape[1] + real_images.shape.assert_is_compatible_with([None, None, height, 3]) + fake_images.shape.assert_is_compatible_with(real_images.shape) + + # Select resolutions. + resolution_full = int(height) + resolution_min = min(resolution_min, resolution_full) + resolution_max = resolution_full + # Base loss of detail. + resolutions = [ + 2**i + for i in range( + int(np.log2(resolution_max)), + int(np.log2(resolution_min)) - 1, -1) + ] + + # Gather patches for each level of the Laplacian pyramids. + patches_real, patches_fake, patches_test = ( + [[] for _ in resolutions] for _ in range(3)) + for lod, level in enumerate( + _laplacian_pyramid(real_images, len(resolutions))): + patches_real[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + patches_test[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + + for lod, level in enumerate( + _laplacian_pyramid(fake_images, len(resolutions))): + patches_fake[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + + for lod in range(len(resolutions)): + for patches in [patches_real, patches_test, patches_fake]: + patches[lod] = _normalize_patches(patches[lod]) + + # Evaluate scores. + scores = [] + for lod in range(len(resolutions)): + if not use_svd: + scores.append( + (_sliced_wasserstein(patches_real[lod], patches_test[lod], + random_sampling_count, random_projection_dim), + _sliced_wasserstein(patches_real[lod], patches_fake[lod], + random_sampling_count, random_projection_dim))) + else: + scores.append( + (_sliced_wasserstein_svd(patches_real[lod], patches_test[lod]), + _sliced_wasserstein_svd(patches_real[lod], patches_fake[lod]))) + return scores diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b960af28eaa969079b72c7aabcde2ad6cd1f5c68 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Sliced Wasserstein Distance.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import ndimage +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl as swd +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class ClassifierMetricsTest(test.TestCase): + + def test_laplacian_pyramid(self): + # The numpy/scipy code for reference estimation comes from: + # https://github.com/tkarras/progressive_growing_of_gans + gaussian_filter = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ + 6, 24, 36, 24, 6 + ], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]) / 256.0 + + def np_pyr_down(minibatch): # matches cv2.pyrDown() + assert minibatch.ndim == 4 + return ndimage.convolve( + minibatch, + gaussian_filter[np.newaxis, np.newaxis, :, :], + mode='mirror')[:, :, ::2, ::2] + + def np_pyr_up(minibatch): # matches cv2.pyrUp() + assert minibatch.ndim == 4 + s = minibatch.shape + res = np.zeros((s[0], s[1], s[2] * 2, s[3] * 2), minibatch.dtype) + res[:, :, ::2, ::2] = minibatch + return ndimage.convolve( + res, + gaussian_filter[np.newaxis, np.newaxis, :, :] * 4.0, + mode='mirror') + + def np_laplacian_pyramid(minibatch, num_levels): + # Note: there's a bug in the original SWD, fixed repeatability. + pyramid = [minibatch.astype('f').copy()] + for _ in range(1, num_levels): + pyramid.append(np_pyr_down(pyramid[-1])) + pyramid[-2] -= np_pyr_up(pyramid[-1]) + return pyramid + + data = np.random.normal(size=[256, 3, 32, 32]).astype('f') + pyramid = np_laplacian_pyramid(data, 3) + data_tf = array_ops.placeholder(dtypes.float32, [256, 32, 32, 3]) + pyramid_tf = swd._laplacian_pyramid(data_tf, 3) + with self.test_session() as sess: + pyramid_tf = sess.run( + pyramid_tf, feed_dict={ + data_tf: data.transpose(0, 2, 3, 1) + }) + for x in range(3): + self.assertAllClose( + pyramid[x].transpose(0, 2, 3, 1), pyramid_tf[x], atol=1e-6) + + def test_sliced_wasserstein_distance(self): + """Test the distance.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 32, 3]) + wfunc = swd.sliced_wasserstein_distance(d1, d2) + with self.test_session() as sess: + wscores = [sess.run(x) for x in wfunc] + self.assertAllClose( + np.array([0.014, 0.014], 'f'), + np.array([x[0] for x in wscores], 'f'), + rtol=0.1) + self.assertAllClose( + np.array([0.014, 0.020], 'f'), + np.array([x[1] for x in wscores], 'f'), + rtol=0.1) + + def test_sliced_wasserstein_distance_svd(self): + """Test the distance.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 32, 3]) + wfunc = swd.sliced_wasserstein_distance(d1, d2, use_svd=True) + with self.test_session() as sess: + wscores = [sess.run(x) for x in wfunc] + self.assertAllClose( + np.array([0.013, 0.013], 'f'), + np.array([x[0] for x in wscores], 'f'), + rtol=0.15) + self.assertAllClose( + np.array([0.014, 0.019], 'f'), + np.array([x[1] for x in wscores], 'f'), + rtol=0.15) + + def test_swd_mismatched(self): + """Test the inputs mismatched shapes are detected.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 31, 3]) + d3 = random_ops.random_normal([256, 31, 32, 3]) + d4 = random_ops.random_normal([255, 32, 32, 3]) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d2) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d3) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d4) + + def test_swd_not_rgb(self): + """Test that only RGB is supported.""" + d1 = random_ops.random_uniform([256, 32, 32, 1]) + d2 = random_ops.random_normal([256, 32, 32, 1]) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 6d0972f8db418d6fcf517cc6f7e96093ae08a9e4..4816daf760143af9f1502873b123ffad8e5ec8ce 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN features module. + +This module includes support for virtual batch normalization, buffer replay, +conditioning, etc. +""" from __future__ import absolute_import from __future__ import division @@ -22,10 +26,12 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * +from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -33,5 +39,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ +_allowed_symbols += random_tensor_pool.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py index 030e37ec679ec58e3b534fd3644ffe1d23173404..2b7bb5f14e7f3d1b3f913d3426efaaae19079ffb 100644 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py +++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tfgan.python.features.clip_weights.""" +"""Tests for features.clip_weights.""" from __future__ import absolute_import from __future__ import division @@ -31,17 +31,18 @@ class ClipWeightsTest(test.TestCase): """Tests for `discriminator_weight_clip`.""" def setUp(self): + super(ClipWeightsTest, self).setUp() self.variables = [variables.Variable(2.0)] self.tuple = collections.namedtuple( 'VarTuple', ['discriminator_variables'])(self.variables) def _test_weight_clipping_helper(self, use_tuple): - loss = self.variables[0] * 2.0 + loss = self.variables[0] opt = training.GradientDescentOptimizer(1.0) if use_tuple: - opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1) + opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1) else: - opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1) + opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1) train_op1 = opt.minimize(loss, var_list=self.variables) train_op2 = opt_clip.minimize(loss, var_list=self.variables) @@ -72,10 +73,14 @@ class ClipWeightsTest(test.TestCase): clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1) else: with self.assertRaisesRegexp(ValueError, 'must be positive'): - clip_weights.clip_weights(opt, self.variables, weight_clip=-1) + clip_weights.clip_variables(opt, self.variables, weight_clip=-1) def test_incorrect_weight_clip_value_argsonly(self): self._test_incorrect_weight_clip_value_helper(False) def test_incorrect_weight_clip_value_tuple(self): self._test_incorrect_weight_clip_value_helper(True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..ca904971fa8cb0440d3e0c9060f13cc214c9eaad --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py @@ -0,0 +1,35 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A tensor pool stores values from an input tensor and returns a stored one. + +See the following papers for more details. +1) `Learning from simulated and unsupervised images through adversarial + training` (https://arxiv.org/abs/1612.07828). +2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial + Networks` (https://arxiv.org/abs/1703.10593). +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import random_tensor_pool_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = random_tensor_pool_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9d10db0f5a3d09dc4dd7d8b1c97c16c29808547c --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -0,0 +1,135 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A tensor pool stores values from an input tensor and returns a stored one. + +We use this to keep a history of values created by a generator, such that +a discriminator can randomly be trained on some older samples, not just the +current one. This can help to not let the discriminator get too far ahead of the +generator and also to keep the system from oscilating, if the discriminator +forgets too fast what past samples from the generator looked like. + +See the following papers for more details. +1) `Learning from simulated and unsupervised images through adversarial + training` (https://arxiv.org/abs/1612.07828). +2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial + Networks` (https://arxiv.org/abs/1703.10593). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import random_ops + +__all__ = [ + 'tensor_pool', +] + + +def _to_tuple(x): + if isinstance(x, (list, tuple)): + return tuple(x) + return (x,) + + +def tensor_pool(input_values, + pool_size=50, + pooling_probability=0.5, + name='tensor_pool'): + """Queue storing input values and returning random previously stored ones. + + Every time the returned `output_value` is evaluated, `input_value` is + evaluated and its value either directly returned (with + `1-pooling_probability`) or stored in the pool and a random one of the samples + currently in the pool is popped and returned. As long as the pool in not fully + filled, the input_value is always directly returned, as well as stored in the + pool. Note during inference / testing, it may be appropriate to set + `pool_size` = 0 or `pooling_probability` = 0. + + Args: + input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read + values to be pooled. + pool_size: An integer specifying the maximum size of the pool. Defaults to + 50. + pooling_probability: A float `Tensor` specifying the probability of getting + a value from the pool, as opposed to just the current input. + name: A string prefix for the name scope for all tensorflow ops. + + Returns: + A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx + `input_values`) which is with given probability either the `input_values` or + a randomly chosen sample that was previously inserted in the pool. + + Raises: + ValueError: If `pool_size` is negative. + """ + pool_size = int(pool_size) + if pool_size < 0: + raise ValueError('`pool_size` is negative.') + elif pool_size == 0: + return input_values + + original_input_values = input_values + input_values = _to_tuple(input_values) + + with ops.name_scope( + '{}_pool_queue'.format(name), + values=input_values + (pooling_probability,)): + pool_queue = data_flow_ops.RandomShuffleQueue( + capacity=pool_size, + min_after_dequeue=0, + dtypes=[v.dtype for v in input_values], + shapes=None) + + # In pseudeo code this code does the following: + # if not pool_full: + # enqueue(input_values) + # return input_values + # else + # dequeue_values = dequeue_random_sample() + # enqueue(input_values) + # if rand() < pooling_probability: + # return dequeue_values + # else + # return input_values + + def _get_input_value_pooled(): + enqueue_op = pool_queue.enqueue(input_values) + with ops.control_dependencies([enqueue_op]): + return tuple(array_ops.identity(v) for v in input_values) + + def _get_random_pool_value_and_enqueue_input(): + dequeue_values = _to_tuple(pool_queue.dequeue()) + with ops.control_dependencies(dequeue_values): + enqueue_op = pool_queue.enqueue(input_values) + with ops.control_dependencies([enqueue_op]): + prob = random_ops.random_uniform( + (), dtype=dtypes.float32) < pooling_probability + return control_flow_ops.cond(prob, lambda: dequeue_values, + lambda: input_values) + + output_values = _to_tuple(control_flow_ops.cond( + pool_queue.size() < pool_size, _get_input_value_pooled, + _get_random_pool_value_and_enqueue_input)) + + if isinstance(original_input_values, list): + return list(output_values) + elif isinstance(original_input_values, tuple): + return output_values + return output_values[0] diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cef3a87ab34f9754099073eefcb3f1b1c97a3762 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -0,0 +1,110 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.gan.python.features.random_tensor_pool.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class TensorPoolTest(test.TestCase): + + def test_pool_unknown_input_shape(self): + """Checks that `input_value` can have unknown shape.""" + input_value = array_ops.placeholder( + dtype=dtypes.int32, shape=[None, None, 3]) + output_value = tensor_pool(input_value, pool_size=10) + + with self.test_session(use_gpu=True) as session: + for i in range(10): + session.run(output_value, {input_value: [[[i] * 3]]}) + session.run(output_value, {input_value: [[[i] * 3] * 2]}) + session.run(output_value, {input_value: [[[i] * 3] * 5] * 2}) + + def test_pool_sequence(self): + """Checks that values are pooled and returned maximally twice.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + output_value = tensor_pool(input_value, pool_size=10) + + with self.test_session(use_gpu=True) as session: + outs = [] + for i in range(50): + out = session.run(output_value, {input_value: i}) + outs.append(out) + self.assertLessEqual(out, i) + + _, counts = np.unique(outs, return_counts=True) + # Check that each value is returned maximally twice. + self.assertTrue((counts <= 2).all()) + + def test_never_pool(self): + """Checks that setting `pooling_probability` to zero works.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + output_value = tensor_pool( + input_value, pool_size=10, pooling_probability=0.0) + + with self.test_session(use_gpu=True) as session: + for i in range(50): + out = session.run(output_value, {input_value: i}) + self.assertEqual(out, i) + + def test_pooling_probability(self): + """Checks that `pooling_probability` works.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + pool_size = 10 + pooling_probability = 0.2 + output_value = tensor_pool( + input_value, + pool_size=pool_size, + pooling_probability=pooling_probability) + + with self.test_session(use_gpu=True) as session: + not_pooled = 0 + total = 1000 + for i in range(total): + out = session.run(output_value, {input_value: i}) + if out == i: + not_pooled += 1 + self.assertAllClose( + (not_pooled - pool_size) / (total - pool_size), + 1 - pooling_probability, + atol=0.03) + + def test_input_values_tuple(self): + """Checks that `input_values` can be a tuple.""" + input_values = (array_ops.placeholder(dtype=dtypes.int32, shape=[]), + array_ops.placeholder(dtype=dtypes.int32, shape=[])) + output_values = tensor_pool(input_values, pool_size=3) + self.assertEqual(len(output_values), len(input_values)) + + with self.test_session(use_gpu=True) as session: + for i in range(10): + outs = session.run(output_values, { + input_values[0]: i, + input_values[1]: i + 1 + }) + self.assertEqual(len(outs), len(input_values)) + self.assertEqual(outs[1] - outs[0], 1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/losses/__init__.py b/tensorflow/contrib/gan/python/losses/__init__.py index 290ff867a1e443f20a63e27fd97f53fed8a6cc11..d9bf8ebfdf65dfc76e4569dcaf26e0e51c7fc107 100644 --- a/tensorflow/contrib/gan/python/losses/__init__.py +++ b/tensorflow/contrib/gan/python/losses/__init__.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN losses and penalties. + +Losses can be used with individual arguments or with GANModel tuples. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 48f5e8e47dbcd5d32c23806b967a0d1e7403d2f7..3d4e315ebd0bd52b3b5e3e4a8655df8bfe9cebe8 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -79,6 +79,7 @@ class InfoGANModel( collections.namedtuple('InfoGANModel', GANModel._fields + ( 'structured_generator_inputs', 'predicted_distributions', + 'discriminator_and_aux_fn', ))): """An InfoGANModel contains all the pieces needed for InfoGAN training. @@ -91,6 +92,8 @@ class InfoGANModel( predicted_distributions: A list of tf.Distributions. Predicted by the recognizer, and used to evaluate the likelihood of the structured noise. List length should match `structured_generator_inputs`. + discriminator_and_aux_fn: The original discriminator function that returns + a tuple of (logits, `predicted_distributions`). """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index ad2d5eb86cdab89273efbd4ddce45f6657b54406..edd0113977ff4ddc672b0ec134be1a48c621b579 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -215,7 +215,8 @@ def infogan_model( disc_scope, lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API structured_generator_inputs, - predicted_distributions) + predicted_distributions, + discriminator_fn) def acgan_model( @@ -326,6 +327,53 @@ def _use_aux_loss(aux_loss_weight): return False +def _tensor_pool_adjusted_model(model, tensor_pool_fn): + """Adjusts model using `tensor_pool_fn`. + + Args: + model: A GANModel tuple. + tensor_pool_fn: A function that takes (generated_data, generator_inputs), + stores them in an internal pool and returns a previously stored + (generated_data, generator_inputs) with some probability. For example + tfgan.features.tensor_pool. + + Returns: + A new GANModel tuple where discriminator outputs are adjusted by taking + pooled generator outputs as inputs. Returns the original model if + `tensor_pool_fn` is None. + + Raises: + ValueError: If tensor pool does not suport the `model`. + """ + if tensor_pool_fn is None: + return model + + pooled_generated_data, pooled_generator_inputs = tensor_pool_fn( + (model.generated_data, model.generator_inputs)) + + if isinstance(model, namedtuples.GANModel): + dis_gen_outputs = model.discriminator_fn(pooled_generated_data, + pooled_generator_inputs) + return model._replace(discriminator_gen_outputs=dis_gen_outputs) + elif isinstance(model, namedtuples.ACGANModel): + (dis_pooled_gen_outputs, + dis_pooled_gen_classification_logits) = model.discriminator_fn( + pooled_generated_data, pooled_generator_inputs) + return model._replace( + discriminator_gen_outputs=dis_pooled_gen_outputs, + discriminator_gen_classification_logits= + dis_pooled_gen_classification_logits) + elif isinstance(model, namedtuples.InfoGANModel): + (dis_pooled_gen_outputs, + pooled_predicted_distributions) = model.discriminator_and_aux_fn( + pooled_generated_data, pooled_generator_inputs) + return model._replace( + discriminator_gen_outputs=dis_pooled_gen_outputs, + predicted_distributions=pooled_predicted_distributions) + else: + raise ValueError('Tensor pool does not support `model`: %s.' % type(model)) + + def gan_loss( # GANModel. model, @@ -338,6 +386,7 @@ def gan_loss( mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, + tensor_pool_fn=None, # Options. add_summaries=True): """Returns losses necessary to train generator and discriminator. @@ -363,6 +412,10 @@ def gan_loss( https://arxiv.org/abs/1610.09585 aux_cond_discriminator_weight: If not None: add a classification loss as in https://arxiv.org/abs/1610.09585 + tensor_pool_fn: A function that takes (generated_data, generator_inputs), + stores them in an internal pool and returns previous stored + (generated_data, generator_inputs). For example + `tf.gan.features.tensor_pool`. Defaults to None (not using tensor pool). add_summaries: Whether or not to add summaries for the losses. Returns: @@ -402,7 +455,9 @@ def gan_loss( # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) - dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries) + dis_loss = discriminator_loss_fn( + _tensor_pool_adjusted_model(model, tensor_pool_fn), + add_summaries=add_summaries) # Add optional extra losses. if _use_aux_loss(gradient_penalty_weight): @@ -422,7 +477,7 @@ def gan_loss( ac_disc_loss = tfgan_losses.acgan_discriminator_loss( model, add_summaries=add_summaries) dis_loss += aux_cond_discriminator_weight * ac_disc_loss - # Gathers auxilliary losses. + # Gathers auxiliary losses. if model.generator_scope: gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name) else: diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 6b27b6926102b6e5a7ff134ceed75c23459a6534..519d101e07f4f28d684017b86102fce8fa7677ef 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python import train +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.slim.python.slim import learning as slim_learning from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -145,14 +146,16 @@ def get_infogan_model(): return namedtuples.InfoGANModel( *get_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def get_callable_infogan_model(): return namedtuples.InfoGANModel( *get_callable_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def create_infogan_model(): @@ -409,6 +412,51 @@ class GANLossTest(test.TestCase): def test_callable_acgan(self): self._test_acgan_helper(create_callable_acgan_model) + # Test tensor pool. + def _test_tensor_pool_helper(self, create_gan_model_fn): + model = create_gan_model_fn() + if isinstance(model, namedtuples.InfoGANModel): + + def tensor_pool_fn_impl(input_values): + generated_data, generator_inputs = input_values + output_values = random_tensor_pool.tensor_pool( + [generated_data] + generator_inputs, pool_size=5) + return output_values[0], output_values[1:] + + tensor_pool_fn = tensor_pool_fn_impl + else: + + def tensor_pool_fn_impl(input_values): + return random_tensor_pool.tensor_pool(input_values, pool_size=5) + + tensor_pool_fn = tensor_pool_fn_impl + loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + self.assertTrue(isinstance(loss, namedtuples.GANLoss)) + + # Check values. + with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + for _ in range(10): + sess.run([loss.generator_loss, loss.discriminator_loss]) + + def test_tensor_pool_gan(self): + self._test_tensor_pool_helper(create_gan_model) + + def test_tensor_pool_callable_gan(self): + self._test_tensor_pool_helper(create_callable_gan_model) + + def test_tensor_pool_infogan(self): + self._test_tensor_pool_helper(create_infogan_model) + + def test_tensor_pool_callable_infogan(self): + self._test_tensor_pool_helper(create_callable_infogan_model) + + def test_tensor_pool_acgan(self): + self._test_tensor_pool_helper(create_acgan_model) + + def test_tensor_pool_callable_acgan(self): + self._test_tensor_pool_helper(create_callable_acgan_model) + def test_doesnt_crash_when_in_nested_scope(self): with variable_scope.variable_scope('outer_scope'): gan_model = train.gan_model( diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 157e97d237021d95c935a6be66aa57842b97125c..54502cfc6eecb9d064ffde9773e97d893a24133a 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -9,6 +9,7 @@ package(default_visibility = ["//visibility:public"]) load( "//tensorflow:tensorflow.bzl", + "tf_cc_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", @@ -106,10 +107,33 @@ tf_custom_op_library( name = "python/ops/_distort_image_ops.so", srcs = [ "kernels/adjust_hsv_in_yiq_op.cc", + "kernels/adjust_hsv_in_yiq_op.h", "ops/distort_image_ops.cc", ], + gpu_srcs = [ + "kernels/adjust_hsv_in_yiq_op_gpu.cu.cc", + "kernels/adjust_hsv_in_yiq_op.h", + ], deps = [ - "@protobuf_archive//:protobuf", + "//tensorflow/core/kernels:gpu_util_hdrs", + ], +) + +tf_cc_test( + name = "adjust_hsv_in_yiq_op_test", + size = "small", + srcs = [ + "kernels/adjust_hsv_in_yiq_op.h", + "kernels/adjust_hsv_in_yiq_op_test.cc", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", ], ) @@ -122,19 +146,6 @@ tf_gen_op_wrapper_py( deps = [":distort_image_ops_op_lib"], ) -cc_library( - name = "distort_image_ops_cc", - srcs = [ - "kernels/adjust_hsv_in_yiq_op.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//third_party/eigen3", - ], - alwayslink = 1, -) - py_library( name = "distort_image_py", srcs = [ diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc index f4962ed69dc68d4bad06ef29d7a167e0ba8ae044..478b716d88321101c971789f36c0ff8ecd3f418e 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc @@ -12,14 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" @@ -36,10 +37,10 @@ class AdjustHsvInYiqOpBase : public OpKernel { struct ComputeOptions { const Tensor* input = nullptr; + Tensor* output = nullptr; const Tensor* delta_h = nullptr; const Tensor* scale_s = nullptr; const Tensor* scale_v = nullptr; - Tensor* output = nullptr; int64 channel_count = 0; }; @@ -65,7 +66,7 @@ class AdjustHsvInYiqOpBase : public OpKernel { scale_v.shape().DebugString())); auto channels = input.dim_size(input.dims() - 1); OP_REQUIRES( - context, channels == 3, + context, channels == kChannelSize, errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); @@ -101,53 +102,21 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { const Tensor* input = options.input; Tensor* output = options.output; const int64 channel_count = options.channel_count; - static const int kChannelSize = 3; auto input_data = input->shaped({channel_count, kChannelSize}); const float delta_h = options.delta_h->scalar()(); const float scale_s = options.scale_s->scalar()(); const float scale_v = options.scale_v->scalar()(); auto output_data = output->shaped({channel_count, kChannelSize}); + float tranformation_matrix[kChannelSize * kChannelSize] = {0}; + internal::compute_tranformation_matrix( + delta_h, scale_s, scale_v, tranformation_matrix); const int kCostPerChannel = 10; const DeviceBase::CpuWorkerThreads& worker_threads = *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, kCostPerChannel, - [channel_count, &input_data, &output_data, delta_h, scale_s, scale_v]( + [channel_count, &input_data, &output_data, &tranformation_matrix]( int64 start_channel, int64 end_channel) { - // Using approximate linear transfomation described in: - // https://beesbuzz.biz/code/hsv_color_transforms.php - /** Get the constants from sympy - from sympy import Matrix - from sympy.abc import u, w - # Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ - tyiq = Matrix([[0.299, 0.587, 0.114], - [0.596, -0.274, -0.322], - [0.211, -0.523, 0.312]]) - # Hue rotation matrix in YIQ space. - hue_proj = Matrix(3,3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu]) - m = tyiq.inv() * hue_proj * tyiq - **/ - // TODO(huangyp): directly compute the projection matrix from tyiq. - static const float t[kChannelSize][kChannelSize][kChannelSize] = { - {{.299, .701, .16862179492229}, - {.587, -.587, .329804745287403}, - {.114, -.114, -0.498426540209694}}, - {{.299, -.299, -.327963394172371}, - {.587, .413, .0346106879248821}, - {.114, -.114, .293352706247489}}, - {{.299, -.299, 1.24646136576682}, - {.587, -.587, -1.04322888291964}, - {.114, .886, -.203232482847173}}}; - float m[kChannelSize][kChannelSize] = {{0.}}; - float su = scale_s * std::cos(delta_h); - float sw = scale_s * std::sin(delta_h); - for (int q_index = 0; q_index < kChannelSize; q_index++) { - for (int p_index = 0; p_index < kChannelSize; p_index++) { - m[q_index][p_index] = scale_v * (t[q_index][p_index][0] + - t[q_index][p_index][1] * su + - t[q_index][p_index][2] * sw); - } - } // Applying projection matrix to input RGB vectors. const float* p = input_data.data() + start_channel * kChannelSize; float* q = output_data.data() + start_channel * kChannelSize; @@ -155,7 +124,9 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { for (int q_index = 0; q_index < kChannelSize; q_index++) { q[q_index] = 0; for (int p_index = 0; p_index < kChannelSize; p_index++) { - q[q_index] += m[q_index][p_index] * p[p_index]; + q[q_index] += + p[p_index] * + tranformation_matrix[q_index + kChannelSize * p_index]; } } p += kChannelSize; @@ -165,8 +136,33 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { } }; -REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU), - AdjustHsvInYiqOp); +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint("T"), + AdjustHsvInYiqOp); + +#if GOOGLE_CUDA +template <> +class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { + public: + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) + : AdjustHsvInYiqOpBase(context) {} + + void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override { + const int64 number_of_elements = options.input->NumElements(); + if (number_of_elements <= 0) { + return; + } + const float* delta_h = options.delta_h->flat().data(); + const float* scale_s = options.scale_s->flat().data(); + const float* scale_v = options.scale_v->flat().data(); + functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input, + delta_h, scale_s, scale_v, options.output); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint("T"), + AdjustHsvInYiqOp); +#endif -// TODO(huangyp): add the GPU kernel } // namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h new file mode 100644 index 0000000000000000000000000000000000000000..194ae2ba47456cac66c01989a78ab4ce607d1295 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +static constexpr int kChannelSize = 3; + +namespace internal { + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_tranformation_matrix( + const float delta_h, const float scale_s, const float scale_v, + float* matrix) { + static_assert(MATRIX_SIZE == kChannelSize * kChannelSize, + "Size of matrix should be 9."); + // Projection matrix from RGB to YIQ. Numbers from wikipedia + // https://en.wikipedia.org/wiki/YIQ + Eigen::Matrix3f yiq; + /* clang-format off */ + yiq << 0.299, 0.587, 0.114, + 0.596, -0.274, -0.322, + 0.211, -0.523, 0.312; + Eigen::Matrix3f yiq_inverse; + yiq_inverse << 1, 0.95617069, 0.62143257, + 1, -0.2726886, -0.64681324, + 1, -1.103744, 1.70062309; + /* clang-format on */ + // Construct hsv linear transformation matrix in YIQ space. + // https://beesbuzz.biz/code/hsv_color_transforms.php + float vsu = scale_v * scale_s * std::cos(delta_h); + float vsw = scale_v * scale_s * std::sin(delta_h); + Eigen::Matrix3f hsv_transform; + /* clang-format off */ + hsv_transform << scale_v, 0, 0, + 0, vsu, -vsw, + 0, vsw, vsu; + /* clang-format on */ + // Compute final transformation matrix = inverse_yiq * hsv_transform * yiq + Eigen::Map> eigen_matrix(matrix); + eigen_matrix = yiq_inverse * hsv_transform * yiq; +} +} // namespace internal + +#if GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +struct AdjustHsvInYiqGPU { + void operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, const float* const delta_h, + const float* const scale_s, const float* const scale_v, + Tensor* const output); +}; + +} // namespace functor + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..b71ff9cd507faac66b3a33d3c02ec9b5901d814a --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +namespace internal { + +__global__ void compute_tranformation_matrix_cuda(const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + float* const matrix, + const int matrix_size) { + if (matrix_size == kChannelSize * kChannelSize) { + compute_tranformation_matrix( + *delta_h, *scale_s, *scale_v, matrix); + } +} +} // namespace internal + +namespace functor { + +void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, + const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + Tensor* const output) { + const uint64 m = channel_count; + const uint64 k = kChannelSize; + const uint64 n = kChannelSize; + auto* cu_stream = ctx->eigen_device().stream(); + OP_REQUIRES(ctx, cu_stream, errors::Internal("No GPU stream available.")); + Tensor tranformation_matrix; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DT_FLOAT, TensorShape({kChannelSize * kChannelSize}), + &tranformation_matrix)); + // TODO(huangyp): It takes about 3.5 us to comute tranformation_matrix + // with one thread. Improve its performance if necessary. + internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>( + delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + // Call cuBlas C = A * B directly. + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + auto a_ptr = + AsDeviceMemory(input->flat().data(), input->flat().size()); + auto b_ptr = AsDeviceMemory(tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + auto c_ptr = AsDeviceMemory(output->flat().data(), + output->flat().size()); + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + // TODO(huangyp): share/use autotune cublas algorithms in Matmul.op. + bool blas_launch_status = + stream + ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } +} +} // namespace functor +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4cbbd277840133c9419f9ce3d945b7d099679dc0 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class AdjustHsvInYiqOpTest : public OpsTestBase { + protected: +}; + +TEST_F(AdjustHsvInYiqOpTest, IdentiyTransformMatrix) { + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, 1.0, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, {1, 0, 0, 0, 1, 0, 0, 0, 1}); + test::ExpectClose(matrix, expected); +} + +TEST_F(AdjustHsvInYiqOpTest, ScaleValueTransformMatrix) { + float scale_v = 2.3; + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, scale_v, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, + {scale_v, 0, 0, 0, scale_v, 0, 0, 0, scale_v}); + test::ExpectClose(matrix, expected); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index 2b6799213827537f77deda4e052bb7ec16f46343..f8b56ab1c5400694b3aa8d4a0c19c7769aa8cbce 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -40,7 +40,7 @@ REGISTER_OP("SingleImageRandomDotStereograms") .Doc(R"doc( Outputs a single image random dot stereogram for export via encode_PNG/JPG OP. -Given the 2-D tensor 'depth_values' with encoded Z values, this operation will +Given the 2-D tensor 'depth_values' with encoded Z values, this operation will encode 3-D data into a 2-D image. The output of this Op is suitable for the encode_PNG/JPG ops. Be careful with image compression as this may corrupt the encode 3-D data witin the image. @@ -68,14 +68,14 @@ with open('picture_out.png', 'wb') as f: f.write(png) ``` -depth_values: Z values of data to encode into 'output_data_window' window, +depth_values: Z values of data to encode into 'output_data_window' window, lower values are further away {0.0 floor(far), 1.0 ceiling(near) after normalization}, must be 2-D tensor hidden_surface_removal: Activate hidden surface removal convergence_dots_size: Black dot size in pixels to help view converge image, drawn on bottom of image dots_per_inch: Output device in dots/inch eye_separation: Separation between eyes in inches mu: Depth of field, Fraction of viewing distance (eg. 1/3 = .3333) -normalize: Normalize input data to [0.0, 1.0] +normalize: Normalize input data to [0.0, 1.0] normalize_max: Fix MAX value for Normalization - if < MIN, autoscale normalize_min: Fix MIN value for Normalization - if > MAX, autoscale border_level: Value of border depth 0.0 {far} to 1.0 {near} diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py index b85f19d29b79defa10493bdbaa4a1b237cb2a9ee..a495b58b7f6481d4cdedf73f23615d0390eb6a45 100644 --- a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py @@ -172,7 +172,7 @@ class AdjustValueInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_np = self._adjust_value_in_yiq_np(x_np, scale) y_tf = self._adjust_value_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -237,7 +237,7 @@ class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale) y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -291,6 +291,9 @@ class AdjustHueInYiqBenchmark(test.Benchmark): def benchmark_adjust_hue_in_yiqCpuAll(self): self._benchmark_adjust_hue_in_yiq('/cpu:0', None) + def benchmark_adjust_hue_in_yiq_gpu_all(self): + self._benchmark_adjust_hue_in_yiq(test.gpu_device_name(), None) + class AdjustSaturationInYiqBenchmark(test.Benchmark): @@ -333,6 +336,9 @@ class AdjustSaturationInYiqBenchmark(test.Benchmark): def benchmark_adjust_saturation_in_yiq_cpu_all(self): self._benchmark_adjust_saturation_in_yiq('/cpu:0', None) + def benchmark_adjust_saturation_in_yiq_gpu_all(self): + self._benchmark_adjust_saturation_in_yiq(test.gpu_device_name(), None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index 5cccf26028ca6bf269dbc67a33075351edecb407..bb766e59d2cee648042cc08be466796d9233ad66 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -68,7 +68,7 @@ def single_image_random_dot_stereograms( ``` Args: - depth_values: A `Tensor`. Must be one of the following types: + depth_values: A `Tensor`. Must be one of the following types: `float64`, `float32`, `int64`, `int32`. Z values of data to encode into 'output_data_window' window, lower further away {0.0 floor(far), 1.0 ceiling(near) after norm}, must be 2-D tensor @@ -84,17 +84,17 @@ def single_image_random_dot_stereograms( mu: An optional `float`. Defaults to `0.3333`. Depth of field, Fraction of viewing distance (eg. 1/3 = 0.3333) normalize: An optional `bool`. Defaults to `True`. - Normalize input data to [0.0, 1.0] + Normalize input data to [0.0, 1.0] normalize_max: An optional `float`. Defaults to `-100`. Fix MAX value for Normalization (0.0) - if < MIN, autoscale normalize_min: An optional `float`. Defaults to `100`. Fix MIN value for Normalization (0.0) - if > MAX, autoscale border_level: An optional `float`. Defaults to `0`. - Value of bord in depth 0.0 {far} to 1.0 {near} + Value of bord in depth 0.0 {far} to 1.0 {near} number_colors: An optional `int`. Defaults to `256`. 2 (Black & White), 256 (grayscale), and Numbers > 256 (Full Color) are supported - output_image_shape: An optional `tf.TensorShape` or list of `ints`. + output_image_shape: An optional `tf.TensorShape` or list of `ints`. Defaults to shape `[1024, 768, 1]`. Defines output shape of returned image in '[X,Y, Channels]' 1-grayscale, 3 color; channels will be updated to 3 if number_colors > 256 diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 7d65ac9a43dd777baa020fe0453af65e69e6c509..95fba59e3c96ae3c69e0b154740785b0d2bcb3c9 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -16,6 +16,7 @@ py_test( "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", @@ -33,6 +34,7 @@ py_test( "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index b52a7b52a7efd4292ad514c5a744c4da07082142..9b28c45c7263208d21b1514ae5f05b7e81e315a3 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -33,6 +34,30 @@ from tensorflow.python.platform import test _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] +class DeviceContextGeneratorTest(test.TestCase): + + def testNoDevice(self): + device_context_generator = estimator._DeviceContextGenerator(None) + with ops.device("/device:CPU:0"): # This is what will be used + with device_context_generator(): # Does nothing + a = constant_op.constant([2.0], name="a") + self.assertEqual("/device:CPU:0", a.op.device) + + def testTwoDevices(self): + device_context_generator = estimator._DeviceContextGenerator( + ["/device:GPU:0", "/device:GPU:1"]) + with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes + with device_context_generator(): + a = constant_op.constant([2.0], name="a") + with device_context_generator(): + b = constant_op.constant([2.0], name="b") + with device_context_generator(): + c = constant_op.constant([2.0], name="c") + self.assertEqual("/device:GPU:0", a.op.device) + self.assertEqual("/device:GPU:1", b.op.device) + self.assertEqual("/device:GPU:0", c.op.device) + + class EstimatorTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 5f2b5c6cace9cd18f4cc5590ff55a9b39680a381..2d9b28185ce0db32d5cd7d84737fdf96e2c98851 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -40,6 +40,21 @@ def _make_psd(dim): return array_ops.constant(mat) +class UtilsTest(test.TestCase): + + def testComputePiTracenorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + left_factor = array_ops.diag([1., 2., 0., 1.]) + right_factor = array_ops.ones([2., 2.]) + + # pi is the sqrt of the left trace norm divided by the right trace norm + pi = fb._compute_pi_tracenorm(left_factor, right_factor) + + pi_val = sess.run(pi) + self.assertEqual(1., pi_val) + + class FullFBTest(test.TestCase): def testFullFBInitSingleTensor(self): @@ -301,8 +316,7 @@ class FullyConnectedDiagonalFB(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -584,8 +598,7 @@ class ConvDiagonalFBTest(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -608,8 +621,9 @@ class ConvDiagonalFBTest(test.TestCase): self.kernel_size, self.kernel_size, self.input_channels + 1, self.output_channels ]) - expected_result = (expected_result[:, :, 0:-1, :], np.reshape( - expected_result[:, :, -1, :], [self.output_channels])) + expected_result = (expected_result[:, :, 0:-1, :], + np.reshape(expected_result[:, :, -1, :], + [self.output_channels])) self.assertEqual(len(result), 2) self.assertAllClose(expected_result[0], result[0]) @@ -692,8 +706,8 @@ class ConvKFCBasicFBTest(test.TestCase): sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) - vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), np.arange( - 2, 4).reshape(2, 1).astype(np.float32)) + vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), + np.arange(2, 4).reshape(2, 1).astype(np.float32)) output = block.multiply_inverse((array_ops.constant(vector[0]), array_ops.constant(vector[1]))) @@ -776,11 +790,50 @@ class ConvKFCBasicFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) +class FullyConnectedSeriesFBTest(test.TestCase): + + def testFullyConnectedSeriesFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), inputs=[inputs], outputs=[outputs]) + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=True) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=False) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def as_tensors(tensor_or_tuple): """Converts a potentially nested tuple of np.array to Tensors.""" if isinstance(tensor_or_tuple, (tuple, list)): return tuple(as_tensors(t) for t in tensor_or_tuple) return ops.convert_to_tensor(tensor_or_tuple) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index fbb3d219139a4bc05253841a89e73645ef37dddd..70e56db055078bd4399b03e4d4a877e34249cc5e 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -22,6 +22,7 @@ import numpy as np import numpy.random as npr from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops from tensorflow.python.framework import random_seed @@ -32,6 +33,25 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test +class MaybeColocateTest(test.TestCase): + + def testFalse(self): + with tf_ops.Graph().as_default(): + a = constant_op.constant([2.0], name='a') + with ff._maybe_colocate_with(a, False): + b = constant_op.constant(3.0, name='b') + self.assertEqual([b'loc:@a'], a.op.colocation_groups()) + self.assertEqual([b'loc:@b'], b.op.colocation_groups()) + + def testTrue(self): + with tf_ops.Graph().as_default(): + a = constant_op.constant([2.0], name='a') + with ff._maybe_colocate_with(a, True): + b = constant_op.constant(3.0, name='b') + self.assertEqual([b'loc:@a'], a.op.colocation_groups()) + self.assertEqual([b'loc:@a'], b.op.colocation_groups()) + + class FisherFactorTestingDummy(ff.FisherFactor): """Dummy class to test the non-abstract methods on ff.FisherFactor.""" @@ -47,12 +67,19 @@ class FisherFactorTestingDummy(ff.FisherFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError def instantiate_covariance(self): pass + def make_inverse_update_ops(self): + return [] + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -74,6 +101,10 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError @@ -101,7 +132,7 @@ class NumericalUtilsTest(test.TestCase): normalizer = 10. x = npr.randn(100, 3) - cov = ff._compute_cov(array_ops.constant(x), normalizer) + cov = ff._compute_cov(array_ops.constant(x), normalizer=normalizer) np_cov = np.dot(x.T, x) / normalizer self.assertAllClose(sess.run(cov), np_cov) @@ -247,13 +278,13 @@ class InverseProvidingFactorTest(test.TestCase): for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): factor.register_damped_inverse(1. / i) ops = factor.make_inverse_update_ops() - self.assertEqual(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD, len(ops)) + self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) new_invs = [] + sess.run(ops) for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): # The inverse op will assign the damped inverse of cov to the inv var. - sess.run(ops[i - 1]) new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): @@ -311,6 +342,16 @@ class FullFactorTest(test.TestCase): factor = ff.FullFactor((tensor,), 32) self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) + def testFullFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 6], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -331,6 +372,16 @@ class NaiveDiagonalFactorTest(test.TestCase): factor = ff.NaiveDiagonalFactor((tensor,), 32) self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + def testNaiveDiagonalFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 1], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -344,18 +395,25 @@ class NaiveDiagonalFactorTest(test.TestCase): class FullyConnectedKroneckerFactorTest(test.TestCase): - def _testFullyConnectedKroneckerFactorInit(self, has_bias, final_shape): + def _testFullyConnectedKroneckerFactorInit(self, + has_bias, + final_shape, + dtype=dtypes.float32_ref): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) - self.assertEqual(final_shape, factor.get_cov().get_shape().as_list()) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual(final_shape, cov.get_shape().as_list()) def testFullyConnectedKroneckerFactorInitNoBias(self): - self._testFullyConnectedKroneckerFactorInit(False, [3, 3]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) def testFullyConnectedKroneckerFactorInitWithBias(self): - self._testFullyConnectedKroneckerFactorInit(True, [4, 4]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: @@ -398,6 +456,18 @@ class ConvInputKroneckerFactorTest(test.TestCase): self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) + def testConvInputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + cov.get_shape().as_list()) + def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -433,6 +503,16 @@ class ConvOutputKroneckerFactorTest(test.TestCase): factor = ff.ConvOutputKroneckerFactor((tensor,)) self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) + def testConvOutputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') + factor = ff.ConvOutputKroneckerFactor((tensor,)) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([5, 5], cov.get_shape().as_list()) + def testConvOutputKroneckerFactorInitNotEnoughDims(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -451,5 +531,49 @@ class ConvOutputKroneckerFactorTest(test.TestCase): self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) +class FullyConnectedMultiKFTest(test.TestCase): + + def testFullyConnectedMultiKFInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) + + def testFullyConnectedMultiKFInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([3, 3], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,)) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index c5ad90d1dc7807ae5214523d4a443fb2430d202f..b8ccbeadd0a9d69edb41fef50e3edb090457adf2 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -128,8 +128,9 @@ class LayerCollectionTest(test.TestCase): key = array_ops.constant(1) lc.register_fully_connected(key, array_ops.constant(2), array_ops.constant(3)) - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_generic(key, 16) + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterSingleParamNotRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -144,16 +145,18 @@ class LayerCollectionTest(test.TestCase): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {x: '1'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block(x, 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterSingleParamRegisteredInTuple(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y): '1'} - lc.register_block(x, 'foo') - self.assertEqual(set(['1']), set(lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block(x, 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleParamNotRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -173,8 +176,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y): '1'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block((x, y), 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterTupleParamRegisteredInSuperset(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -183,8 +187,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y, z): '1'} - lc.register_block((x, y), 'foo') - self.assertEqual(set(['1']), set(lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleParamSomeRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -193,10 +198,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} - lc.register_block((x, y), MockFisherBlock('foo')) - self.assertEqual( - set([MockFisherBlock('2'), MockFisherBlock('foo')]), set( - lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), MockFisherBlock('foo')) + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleVarSomeRegisteredInOtherTuples(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -206,8 +210,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, z): '1', (z, w): '2'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterCategoricalPredictiveDistribution(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -427,6 +432,23 @@ class LayerCollectionTest(test.TestCase): self.ensureLayerReuseWorks(register_fn) + def testReuseWithInvalidRegistration(self): + """Invalid registrations shouldn't overwrite existing blocks.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 5, 5, 10]) + outputs = array_ops.zeros([2, 5, 5, 3]) + w = variable_scope.get_variable('w', [1, 1, 10, 3]) + b = variable_scope.get_variable('b', [3]) + lc = layer_collection.LayerCollection() + lc.register_fully_connected(w, inputs, outputs) + self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + with self.assertRaises(KeyError): + lc.register_fully_connected((w, b), inputs, outputs, reuse=True) + self.assertNotIn((w, b), lc.fisher_blocks) + self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + lc.register_fully_connected(w, inputs, outputs, reuse=True) + self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 2) + def testMakeOrGetFactor(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 55fe38e3e9aab2dbd70a45cdc8fa0c208b036db0..d255a6e7160386d8eb6fca00765eea8a318f4eaa 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -222,18 +222,6 @@ class UtilsTest(test.TestCase): self.assertAllClose(b, np.array([4., 5.])) self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) - def testComputePi(self): - with ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) - - # pi is the sqrt of the left trace norm divided by the right trace norm - pi = utils.compute_pi(left_factor, right_factor) - - pi_val = sess.run(pi) - self.assertEqual(1., pi_val) - def testPosDefInvCholesky(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index de4b8920b849dbf2117657de6e7c26f94f4d0363..3d731c7bc206d6f168e9b8f29b66bf4f1dbe8542 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -38,6 +38,7 @@ py_library( ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:special_math_ops", @@ -171,6 +172,7 @@ py_library( deps = [ ":utils", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:util", "//third_party/py/numpy", diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index ce4e776324bbde1b8f214d89daa876032d8a21ff..5e1680967c184bf19f2a2578219db07a48264dc9 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -18,16 +18,53 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math +import contextlib +import itertools import numpy as np from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.util import nest +class _DeviceContextGenerator(object): + """Class for generating device contexts in a round-robin fashion.""" + + def __init__(self, devices): + """Creates a _DeviceContextGenerator object. + + Example usage: + + ```python + dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1']) + with dcg(): + # All operations in this context will be placed on GPU 0 + ... + with dcg(): + # All operations in this context will be placed on GPU 1 + ... + ``` + + Args: + devices: An iterable of device strings (or None). Successive calls to + __call__ will give contexts which place devices on these devices in + a round-robin fashion. + """ + self._cycle = None if devices is None else itertools.cycle(devices) + + @contextlib.contextmanager + def __call__(self): + """Returns a context manager specifying the default device.""" + if self._cycle is None: + yield + else: + with tf_ops.device(next(self._cycle)): + yield + + class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher.""" @@ -36,7 +73,10 @@ class FisherEstimator(object): cov_ema_decay, damping, layer_collection, - estimation_mode="gradients"): + estimation_mode="gradients", + colocate_gradients_with_ops=False, + cov_devices=None, + inv_devices=None): """Create a FisherEstimator object. Args: @@ -54,7 +94,7 @@ class FisherEstimator(object): blocks, kronecker factors, and losses associated with the graph. estimation_mode: The type of estimator to use for the Fishers. Can be - 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + 'gradients', 'empirical', 'curvature_prop', or 'exact'. (Default: 'gradients'). 'gradients' is the basic estimation approach from the original K-FAC paper. 'empirical' computes the 'empirical' Fisher information matrix (which uses the data's distribution for the @@ -69,6 +109,14 @@ class FisherEstimator(object): for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking. + colocate_gradients_with_ops: Whether we should request gradients be + colocated with their respective ops. + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. Raises: ValueError: If no losses have been registered with layer_collection. @@ -79,13 +127,19 @@ class FisherEstimator(object): self._estimation_mode = estimation_mode self._layers = layer_collection self._layers.create_subgraph() - self._check_registration(variables) + self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, "curvature_prop": self._get_grads_lists_curvature_prop, "exact": self._get_grads_lists_exact } + self._colocate_gradients_with_ops = colocate_gradients_with_ops + self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) + if inv_devices == cov_devices: + self._inv_device_context_generator = self._cov_device_context_generator + else: + self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) setup = self._setup(cov_ema_decay) self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup @@ -148,49 +202,6 @@ class FisherEstimator(object): return self._apply_transformation(vecs_and_vars, lambda fb, vec: fb.multiply(vec)) - def _check_registration(self, variables): - """Checks that all variable uses have been registered properly. - - Args: - variables: List of variables. - - Raises: - ValueError: If any registered variables are not included in the list. - ValueError: If any variable in the list is not registered. - ValueError: If any variable in the list is registered with the wrong - number of "uses" in the subgraph recorded (vs the number of times that - variable is actually used in the subgraph). - """ - # Note that overlapping parameters (i.e. those that share variables) will - # be caught by layer_collection.LayerParametersDict during registration. - - reg_use_map = self._layers.get_use_count_map() - - error_messages = [] - - for var in variables: - total_uses = self._layers.subgraph.variable_uses(var) - reg_uses = reg_use_map[var] - - if reg_uses == 0: - error_messages.append("Variable {} not registered.".format(var)) - elif (not math.isinf(reg_uses)) and reg_uses != total_uses: - error_messages.append( - "Variable {} registered with wrong number of uses ({} " - "registrations vs {} uses).".format(var, reg_uses, total_uses)) - - num_get_vars = len(reg_use_map) - - if num_get_vars > len(variables): - error_messages.append("{} registered variables were not included in list." - .format(num_get_vars - len(variables))) - - if error_messages: - error_messages = [ - "Found the following errors with variable registration:" - ] + error_messages - raise ValueError("\n\t".join(error_messages)) - def _setup(self, cov_ema_decay): """Sets up the various operations. @@ -219,8 +230,13 @@ class FisherEstimator(object): raise ValueError("Unrecognized value {} for estimation_mode.".format( self._estimation_mode)) + # TODO(b/68033310): This loop round-robins the "concat" operations which + # gather the inputs for the cov_updates. In future, we might do these + # computations locally then communicate the results, which would require a + # modification to this code. for grads_list, fb in zip(grads_lists, fisher_blocks_list): - fb.instantiate_factors(grads_list, self.damping) + with self._cov_device_context_generator(): + fb.instantiate_factors(grads_list, self.damping) cov_updates = [ factor.make_covariance_update_op(cov_ema_decay) @@ -233,18 +249,23 @@ class FisherEstimator(object): def _get_all_inverse_update_ops(self): for factor in self._layers.get_factors(): - for op in factor.make_inverse_update_ops(): - yield op + with self._inv_device_context_generator(): + for op in factor.make_inverse_update_ops(): + yield op def _get_grads_lists_gradients(self, tensors): - grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(), - nest.flatten(tensors)) + grads_flat = gradients_impl.gradients( + self._layers.total_sampled_loss(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_empirical(self, tensors): - grads_flat = gradients_impl.gradients(self._layers.total_loss(), - nest.flatten(tensors)) + grads_flat = gradients_impl.gradients( + self._layers.total_loss(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) @@ -262,11 +283,13 @@ class FisherEstimator(object): grads_flat = gradients_impl.gradients( nest.flatten(loss_inputs), nest.flatten(tensors), - grad_ys=nest.flatten(transformed_random_signs)) + grad_ys=nest.flatten(transformed_random_signs), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_exact(self, tensors): + """No docstring required.""" # Loop over all coordinates of all losses. grads_all = [] for loss in self._layers.losses: @@ -274,6 +297,9 @@ class FisherEstimator(object): transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( index) grads_flat = gradients_impl.gradients( - loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot) + loss.inputs, + nest.flatten(tensors), + grad_ys=transformed_one_hot, + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return zip(*grads_all) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index e822a1213a4132522be8031401609c78572cb1a6..1ccb9e040f2bb6bcfd217886918abd40e3cc1cfb 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -38,6 +38,7 @@ from __future__ import division from __future__ import print_function import abc +import enum # pylint: disable=g-bad-import-order import six @@ -52,14 +53,54 @@ from tensorflow.python.ops import math_ops # damping /= num_replications ** NORMALIZE_DAMPING_POWER NORMALIZE_DAMPING_POWER = 1.0 +# Methods for adjusting damping for FisherBlocks. See +# _compute_pi_adjusted_damping() for details. +PI_OFF_NAME = "off" +PI_TRACENORM_NAME = "tracenorm" +PI_TYPE = PI_TRACENORM_NAME -def set_global_constants(normalize_damping_power=None): + +def set_global_constants(normalize_damping_power=None, pi_type=None): """Sets various global constants used by the classes in this module.""" global NORMALIZE_DAMPING_POWER + global PI_TYPE if normalize_damping_power is not None: NORMALIZE_DAMPING_POWER = normalize_damping_power + if pi_type is not None: + PI_TYPE = pi_type + + +def _compute_pi_tracenorm(left_cov, right_cov): + """Computes the scalar constant pi for Tikhonov regularization/damping. + + pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. + + Args: + left_cov: The left Kronecker factor "covariance". + right_cov: The right Kronecker factor "covariance". + + Returns: + The computed scalar constant pi for these Kronecker Factors (as a Tensor). + """ + # Instead of dividing by the dim of the norm, we multiply by the dim of the + # other norm. This works out the same in the ratio. + left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] + right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] + return math_ops.sqrt(left_norm / right_norm) + + +def _compute_pi_adjusted_damping(left_cov, right_cov, damping): + + if PI_TYPE == PI_TRACENORM_NAME: + pi = _compute_pi_tracenorm(left_cov, right_cov) + return (damping * pi, damping / pi) + + elif PI_TYPE == PI_OFF_NAME: + return (damping, damping) + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): @@ -153,7 +194,7 @@ class FullFB(FisherBlock): self._factor.register_damped_inverse(damping) def multiply_inverse(self, vector): - inverse = self._factor.get_inverse(self._damping) + inverse = self._factor.get_damped_inverse(self._damping) out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) return utils.column_to_tensors(vector, out_flat) @@ -411,7 +452,7 @@ class ConvDiagonalFB(FisherBlock): (self._strides[1] * self._strides[2])) if NORMALIZE_DAMPING_POWER: - damping /= self._num_locations ** NORMALIZE_DAMPING_POWER + damping /= self._num_locations**NORMALIZE_DAMPING_POWER self._damping = damping self._factor = self._layer_collection.make_or_get_factor( @@ -465,11 +506,10 @@ class KroneckerProductFB(FisherBlock): Args: damping: The base damping factor (float or Tensor) for the damped inverse. """ - pi = utils.compute_pi(self._input_factor.get_cov(), - self._output_factor.get_cov()) - - self._input_damping = (damping**0.5) * pi - self._output_damping = (damping**0.5) / pi + self._input_damping, self._output_damping = _compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) self._input_factor.register_damped_inverse(self._input_damping) self._output_factor.register_damped_inverse(self._output_damping) @@ -487,8 +527,9 @@ class KroneckerProductFB(FisherBlock): return 1.0 def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_inverse(self._output_damping) + left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) + right_factor_inv = self._output_factor.get_damped_inverse( + self._output_damping) reshaped_vector = utils.layer_params_to_mat2d(vector) reshaped_out = math_ops.matmul(left_factor_inv, math_ops.matmul(reshaped_vector, @@ -720,3 +761,260 @@ def _concat_along_batch_dim(tensor_list): def _num_conv_locations(input_shape, strides): """Returns the number of locations a Conv kernel is applied to.""" return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + + +class FullyConnectedMultiIndepFB(KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters. + """ + + def __init__(self, layer_collection, inputs, outputs, has_bias=False): + """Creates a FullyConnectedMultiIndepFB block. + + Args: + layer_collection: LayerCollection instance. + inputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + inputs_size]. + outputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + outputs_size]. + has_bias: bool. If True, estimates Fisher with respect to a bias + parameter as well as the layer's parameters. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_uses = len(inputs) + + super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. + return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, + ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + if NORMALIZE_DAMPING_POWER: + damping /= self._num_uses**NORMALIZE_DAMPING_POWER + + self._register_damped_input_and_output_inverses(damping) + + @property + def _renorm_coeff(self): + return self._num_uses + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) + + +class SeriesFBApproximation(enum.IntEnum): + """See FullyConnectedSeriesFB.__init__ for description and usage.""" + option1 = 1 + option2 = 2 + + +class FullyConnectedSeriesFB(FisherBlock): + """FisherBlock for fully-connected layers that share parameters across time. + + See the following preprint for details: + https://openreview.net/pdf?id=HyMTkQZAb + + See the end of the appendix of the paper for a pseudo-code of the + algorithm being implemented by multiply_inverse here. Note that we are + using pre-computed versions of certain matrix-matrix products to speed + things up. This is explicitly explained wherever it is done. + """ + + def __init__(self, + layer_collection, + inputs, + outputs, + has_bias=False, + option=SeriesFBApproximation.option2): + """Constructs a new `FullyConnectedSeriesFB`. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + inputs: List of tensors of shape [batch_size, input_size]. + Inputs to the layer. + outputs: List of tensors of shape [batch_size, input_size]. + Outputs of the layer (before activations). + has_bias: Whether the layer includes a bias parameter. + option: A `SeriesFBApproximation` specifying the simplifying assumption + to be used in this block. `option1` approximates the cross-covariance + over time as a symmetric matrix, while `option2` makes + the assumption that training sequences are infinitely long. See section + 3.5 of the paper for more details. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_timesteps = len(inputs) + self._option = option + + super(FullyConnectedSeriesFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. + return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + if NORMALIZE_DAMPING_POWER: + damping /= self._num_timesteps**NORMALIZE_DAMPING_POWER + + self._damping_input, self._damping_output = _compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) + + if self._option == SeriesFBApproximation.option1: + self._input_factor.register_option1quants(self._damping_input) + self._output_factor.register_option1quants(self._damping_output) + elif self._option == SeriesFBApproximation.option2: + self._input_factor.register_option2quants(self._damping_input) + self._output_factor.register_option2quants(self._damping_output) + else: + raise ValueError( + "Unrecognized FullyConnectedSeriesFB approximation: {}".format( + self._option)) + + def multiply_inverse(self, vector): + # pylint: disable=invalid-name + + Z = utils.layer_params_to_mat2d(vector) + + # Derivations were done for "batch_dim==1" case so we need to convert to + # that orientation: + Z = array_ops.transpose(Z) + + if self._option == SeriesFBApproximation.option1: + + # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. + L_A, psi_A = self._input_factor.get_option1quants(self._damping_input) + L_G, psi_G = self._output_factor.get_option1quants(self._damping_output) + + def gamma(x): + # We are assuming that each case has the same number of time-steps. + # If this stops being the case one shouldn't simply replace this T + # with its average value. Instead, one needs to go back to the + # definition of the gamma function from the paper. + T = self._num_timesteps + return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) + + # Y = gamma( psi_G*psi_A^T ) (computed element-wise) + # Even though Y is Z-independent we are recomputing it from the psi's + # each since Y depends on both A and G quantities, and it is relatively + # cheap to compute. + Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) + + # Z = L_G^T * Z * L_A + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = U_G^T * Z * U_A + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) + + # Z = Z .* Y + Z *= Y + + # Z = L_G * Z * L_A^T + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = U_G * Z * U_A^T + # Z = G0^(-1/2) * Z * A0^(-1/2) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) + + elif self._option == SeriesFBApproximation.option2: + + # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), + # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. + P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input) + P_G, K_G, mu_G = self._output_factor.get_option2quants( + self._damping_output) + + # Our approach differs superficially from the pseudo-code in the paper + # in order to reduce the total number of matrix-matrix multiplies. + # In particular, the first three computations in the pseudo code are + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = Z - hPsi_G^T * Z * hPsi_A + # Z = E_G^T * Z * E_A + # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that + # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2) + # the entire computation can be written as + # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A + # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A + # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A + # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A + # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A + # This final expression is computed by the following two lines: + # Z = Z - P_G * Z * P_A^T + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) + # Z = K_G^T * Z * K_A + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) + + # Z = Z ./ (1*1^T - mu_G*mu_A^T) + # Be careful with the outer product. We don't want to accidentally + # make it an inner-product instead. + tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A + # Prevent some numerical issues by setting any 0.0 eigs to 1.0 + tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) + Z /= tmp + + # We now perform the transpose/reverse version of the operations + # derived above, whose derivation from the original pseudo-code is + # analgous. + # Z = K_G * Z * K_A^T + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) + + # Z = Z - P_G^T * Z * P_A + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) + + # Z = normalize (1/E[T]) * Z + # Note that this normalization is done because we compute the statistics + # by averaging, not summing, over time. (And the gradient is presumably + # summed over time, not averaged, and thus their scales are different.) + Z /= math_ops.cast(self._num_timesteps, Z.dtype) + + # Convert back to the "batch_dim==0" orientation. + Z = array_ops.transpose(Z) + + return utils.mat2d_to_layer_params(vector, Z) + + # pylint: enable=invalid-name + + def multiply(self, vector): + raise NotImplementedError + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 4e36813369e69de1d6f13ddb00566bda912244f6..5a6d1a93ff217c3922f45a047b4d548086ac5258 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import abc +import contextlib import numpy as np import six @@ -26,6 +27,8 @@ import six from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import special_math_ops @@ -50,7 +53,22 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 EIGENVALUE_CLIPPING_THRESHOLD = 0.0 -def set_global_constants(init_covariances_at_zero=None, zero_debias=None, +@contextlib.contextmanager +def _maybe_colocate_with(op, colocate_cov_ops_with_inputs): + """Context to colocate with `op` if `colocate_cov_ops_with_inputs`.""" + if colocate_cov_ops_with_inputs: + if isinstance(op, (list, tuple)): + with tf_ops.colocate_with(op[0]): + yield + else: + with tf_ops.colocate_with(op): + yield + else: + yield + + +def set_global_constants(init_covariances_at_zero=None, + zero_debias=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None): """Sets various global constants used by the classes in this module.""" @@ -85,7 +103,7 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) -def _compute_cov(tensor, normalizer=None): +def _compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. This function is meant to be applied to random matrices for which the true row @@ -93,6 +111,8 @@ def _compute_cov(tensor, normalizer=None): Args: tensor: A 2D Tensor. + tensor_right: An optional 2D Tensor. If provided, this function computes + the matrix product tensor^T * tensor_right instead of tensor^T * tensor. normalizer: optional scalar for the estimator (by default, the normalizer is the number of rows of tensor). @@ -101,9 +121,14 @@ def _compute_cov(tensor, normalizer=None): """ if normalizer is None: normalizer = array_ops.shape(tensor)[0] - cov = (math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( - normalizer, tensor.dtype)) - return (cov + array_ops.transpose(cov)) / math_ops.cast(2, cov.dtype) + if tensor_right is None: + cov = ( + math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( + normalizer, tensor.dtype)) + return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) + else: + return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / + math_ops.cast(normalizer, tensor.dtype)) def _append_homog(tensor): @@ -119,7 +144,7 @@ def _append_homog(tensor): rank = len(tensor.shape.as_list()) shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) ones = array_ops.ones(shape, dtype=tensor.dtype) - return array_ops.concat([tensor, ones], axis=rank-1) + return array_ops.concat([tensor, ones], axis=rank - 1) def scope_string_from_params(params): @@ -157,8 +182,8 @@ def scope_string_from_params(params): elif isinstance(param, (tf_ops.Tensor, variables.Variable)): name_parts.append(scope_string_from_name(param)) else: - raise ValueError( - "Encountered an unsupported param type {}".format(type(param))) + raise ValueError("Encountered an unsupported param type {}".format( + type(param))) return "_".join(name_parts) @@ -209,6 +234,10 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _dtype(self): + pass + @property def _cov_initializer(self): return covariance_initializer @@ -220,7 +249,8 @@ class FisherFactor(object): "cov", initializer=self._cov_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) @abc.abstractmethod def _compute_new_cov(self, idx=0): @@ -240,9 +270,10 @@ class FisherFactor(object): return moving_averages.assign_moving_average( self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + @abc.abstractmethod def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - return [] + pass def get_cov(self): return self._cov @@ -257,6 +288,13 @@ class InverseProvidingFactor(FisherFactor): _cov_shape properties. """ + # TODO(b/69108481): This class (and its subclasses) should be refactored to + # serve the matrix quantities it computes as both (potentially stale) + # variables, updated by the inverse update ops, and fresh values stored in + # tensors that recomputed once every session.run() call. Currently matpower + # and damp_inverse have the former behavior, while eigendecomposition has + # the latter. + def __init__(self): self._inverses_by_damping = {} self._matpower_by_exp_and_damping = {} @@ -267,6 +305,10 @@ class InverseProvidingFactor(FisherFactor): def register_damped_inverse(self, damping): """Registers a damped inverse needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_inverse. + Args: damping: The damping value (float or Tensor) for this factor. """ @@ -277,12 +319,17 @@ class InverseProvidingFactor(FisherFactor): "inv_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._inverses_by_damping[damping] = inv def register_matpower(self, exp, damping): """Registers a matrix power needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_matpower. + Args: exp: The exponent (float or Tensor) to raise the matrix to. damping: The damping value (float or Tensor). @@ -295,57 +342,78 @@ class InverseProvidingFactor(FisherFactor): "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._matpower_by_exp_and_damping[(exp, damping)] = matpower def register_eigendecomp(self): - """Registers that an eigendecomposition is needed by a FisherBlock.""" + """Registers an eigendecomposition. + + Unlike register_damp_inverse and register_matpower this doesn't create + any variables or inverse ops. Instead it merely makes tensors containing + the eigendecomposition available to anyone that wants them. They will be + recomputed (once) for each session.run() call (when they needed by some op). + """ if not self._eigendecomp: - self._eigendecomp = linalg_ops.self_adjoint_eig(self._cov) + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least FLAGS.eigenvalue_clipping_threshold + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - ops = super(InverseProvidingFactor, self).make_inverse_update_ops() + ops = [] num_inverses = len(self._inverses_by_damping) matrix_power_registered = bool(self._matpower_by_exp_and_damping) - use_eig = (self._eigendecomp or matrix_power_registered or - num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + use_eig = ( + self._eigendecomp or matrix_power_registered or + num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) if use_eig: self.register_eigendecomp() # ensures self._eigendecomp is set eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence - # The matrix self._cov is positive semidefinite by construction, but the - # numerical eigenvalues could be negative due to numerical errors, so here - # we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD. - clipped_eigenvalues = math_ops.maximum(eigenvalues, - EIGENVALUE_CLIPPING_THRESHOLD) - for damping, inv in self._inverses_by_damping.items(): ops.append( inv.assign( - math_ops.matmul(eigenvectors / (clipped_eigenvalues + damping), + math_ops.matmul(eigenvectors / (eigenvalues + damping), array_ops.transpose(eigenvectors)))) for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): ops.append( matpower.assign( - math_ops.matmul(eigenvectors * (clipped_eigenvalues + damping)** - exp, array_ops.transpose(eigenvectors)))) + math_ops.matmul(eigenvectors * + (eigenvalues + damping)**exp, + array_ops.transpose(eigenvectors)))) + # These ops share computation and should be run on a single device. + ops = [control_flow_ops.group(*ops)] else: for damping, inv in self._inverses_by_damping.items(): ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) return ops - def get_inverse(self, damping): + def get_damped_inverse(self, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._inverses_by_damping[damping] def get_matpower(self, exp, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._matpower_by_exp_and_damping[(exp, damping)] def get_eigendecomp(self): + # Unlike get_inverse and get_matpower this doesn't retrieve a stored + # variable, but instead always computes a fresh version from the current + # value of get_cov(). return self._eigendecomp @@ -356,12 +424,21 @@ class FullFactor(InverseProvidingFactor): to any type of parameter in principle, but has very high variance. """ - def __init__(self, params_grads, batch_size): + def __init__(self, + params_grads, + batch_size, + colocate_cov_ops_with_inputs=False): self._batch_size = batch_size + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs self._orig_params_grads_name = scope_string_from_params( [params_grads, self._batch_size]) - self._params_grads_flat = tuple( - utils.tensors_to_column(params_grad) for params_grad in params_grads) + params_grads_flat = [] + for params_grad in params_grads: + with _maybe_colocate_with(params_grad, + self._colocate_cov_ops_with_inputs): + col = utils.tensors_to_column(params_grad) + params_grads_flat.append(col) + self._params_grads_flat = tuple(params_grads_flat) super(FullFactor, self).__init__() @property @@ -377,11 +454,17 @@ class FullFactor(InverseProvidingFactor): def _num_sources(self): return len(self._params_grads_flat) + @property + def _dtype(self): + return self._params_grads_flat[0].dtype + def _compute_new_cov(self, idx=0): # This will be a very basic rank 1 estimate - return ((self._params_grads_flat[idx] * array_ops.transpose( - self._params_grads_flat[idx])) / math_ops.cast( - self._batch_size, self._params_grads_flat[idx].dtype)) + with _maybe_colocate_with(self._params_grads_flat[idx], + self._colocate_cov_ops_with_inputs): + return ((self._params_grads_flat[idx] * array_ops.transpose( + self._params_grads_flat[idx])) / math_ops.cast( + self._batch_size, self._params_grads_flat[idx].dtype)) class DiagonalFactor(FisherFactor): @@ -394,6 +477,9 @@ class DiagonalFactor(FisherFactor): def _cov_initializer(self): return diagonal_covariance_initializer + def make_inverse_update_ops(self): + return [] + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. @@ -402,10 +488,19 @@ class NaiveDiagonalFactor(DiagonalFactor): to any type of parameter in principle, but has very high variance. """ - def __init__(self, params_grads, batch_size): + def __init__(self, + params_grads, + batch_size, + colocate_cov_ops_with_inputs=False): self._batch_size = batch_size - self._params_grads = tuple( - utils.tensors_to_column(params_grad) for params_grad in params_grads) + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs + params_grads_flat = [] + for params_grad in params_grads: + with _maybe_colocate_with(params_grad, + self._colocate_cov_ops_with_inputs): + col = utils.tensors_to_column(params_grad) + params_grads_flat.append(col) + self._params_grads = tuple(params_grads_flat) self._orig_params_grads_name = scope_string_from_params( [self._params_grads, self._batch_size]) super(NaiveDiagonalFactor, self).__init__() @@ -422,9 +517,15 @@ class NaiveDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._params_grads) + @property + def _dtype(self): + return self._params_grads[0].dtype + def _compute_new_cov(self, idx=0): - return (math_ops.square(self._params_grads[idx]) / math_ops.cast( - self._batch_size, self._params_grads[idx].dtype)) + with _maybe_colocate_with(self._params_grads[idx], + self._colocate_cov_ops_with_inputs): + return (math_ops.square(self._params_grads[idx]) / math_ops.cast( + self._batch_size, self._params_grads[idx].dtype)) class FullyConnectedDiagonalFactor(DiagonalFactor): @@ -440,7 +541,11 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, has_bias=False): + def __init__(self, + inputs, + outputs_grads, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Instantiate FullyConnectedDiagonalFactor. Args: @@ -449,18 +554,22 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): outputs_grads: List of Tensors of shape [batch_size, output_size]. Gradient of loss with respect to layer's preactivations. has_bias: bool. If True, append '1' to each input. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs self._batch_size = array_ops.shape(inputs)[0] - self._orig_tensors_name = scope_string_from_params((inputs,) + - tuple(outputs_grads)) + self._orig_tensors_name = scope_string_from_params( + (inputs,) + tuple(outputs_grads)) # Note that we precompute the required operations on the inputs since the # inputs don't change with the 'idx' argument to _compute_new_cov. (Only # the target entry of _outputs_grads changes with idx.) - if has_bias: - inputs = _append_homog(inputs) - self._squared_inputs = math_ops.square(inputs) + with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): + if has_bias: + inputs = _append_homog(inputs) + self._squared_inputs = math_ops.square(inputs) super(FullyConnectedDiagonalFactor, self).__init__() @@ -476,17 +585,23 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. - new_cov = math_ops.matmul( - self._squared_inputs, - math_ops.square(self._outputs_grads[idx]), - transpose_a=True) - new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) - return new_cov + with _maybe_colocate_with(self._squared_inputs, + self._colocate_cov_ops_with_inputs): + new_cov = math_ops.matmul( + self._squared_inputs, + math_ops.square(self._outputs_grads[idx]), + transpose_a=True) + new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) + return new_cov class ConvDiagonalFactor(DiagonalFactor): @@ -494,8 +609,14 @@ class ConvDiagonalFactor(DiagonalFactor): # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, filter_shape, strides, padding, - has_bias=False): + def __init__(self, + inputs, + outputs_grads, + filter_shape, + strides, + padding, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Creates a ConvDiagonalFactor object. Args: @@ -510,29 +631,36 @@ class ConvDiagonalFactor(DiagonalFactor): padding: The padding in this layer (1-D of Tensor length 4). has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._filter_shape = filter_shape self._has_bias = has_bias self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - self._orig_tensors_name = scope_string_from_name((inputs,) - + tuple(outputs_grads)) + self._orig_tensors_name = scope_string_from_name( + (inputs,) + tuple(outputs_grads)) # Note that we precompute the required operations on the inputs since the # inputs don't change with the 'idx' argument to _compute_new_cov. (Only # the target entry of _outputs_grads changes with idx.) - filter_height, filter_width, _, _ = self._filter_shape - patches = array_ops.extract_image_patches( - inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=strides, - rates=[1, 1, 1, 1], - padding=padding) + with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): + filter_height, filter_width, _, _ = self._filter_shape - if has_bias: - patches = _append_homog(patches) + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + patches = array_ops.extract_image_patches( + inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=strides, + rates=[1, 1, 1, 1], + padding=padding) + + if has_bias: + patches = _append_homog(patches) - self._patches = patches + self._patches = patches super(ConvDiagonalFactor, self).__init__() @@ -543,21 +671,29 @@ class ConvDiagonalFactor(DiagonalFactor): @property def _cov_shape(self): filter_height, filter_width, in_channels, out_channels = self._filter_shape - return [filter_height * filter_width * in_channels + self._has_bias, - out_channels] + return [ + filter_height * filter_width * in_channels + self._has_bias, + out_channels + ] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): - outputs_grad = self._outputs_grads[idx] - batch_size = array_ops.shape(self._patches)[0] + with _maybe_colocate_with(self._outputs_grads[idx], + self._colocate_cov_ops_with_inputs): + outputs_grad = self._outputs_grads[idx] + batch_size = array_ops.shape(self._patches)[0] - new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) + new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) - return new_cov + return new_cov def _convdiag_sum_of_squares(self, patches, outputs_grad): # This computes the sum of the squares of the per-training-case "gradients". @@ -572,19 +708,24 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ - def __init__(self, tensors, has_bias=False): + def __init__(self, + tensors, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Instantiate FullyConnectedKroneckerFactor. Args: tensors: List of Tensors of shape [batch_size, n]. Represents either a layer's inputs or its output's gradients. - has_bias: bool. If True, assume this factor is for the layer's inputs and - append '1' to each row. + has_bias: bool. If True, append '1' to each row. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ # The tensor argument is either a tensor of input activations or a tensor of # output pre-activation gradients. self._has_bias = has_bias self._tensors = tensors + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(FullyConnectedKroneckerFactor, self).__init__() @property @@ -601,11 +742,17 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._tensors) + @property + def _dtype(self): + return self._tensors[0].dtype + def _compute_new_cov(self, idx=0): - tensor = self._tensors[idx] - if self._has_bias: - tensor = _append_homog(tensor) - return _compute_cov(tensor) + with _maybe_colocate_with(self._tensors[idx], + self._colocate_cov_ops_with_inputs): + tensor = self._tensors[idx] + if self._has_bias: + tensor = _append_homog(tensor) + return _compute_cov(tensor) class ConvInputKroneckerFactor(InverseProvidingFactor): @@ -618,7 +765,13 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, inputs, filter_shape, strides, padding, has_bias=False): + def __init__(self, + inputs, + filter_shape, + strides, + padding, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Initializes ConvInputKroneckerFactor. Args: @@ -630,12 +783,15 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): width_stride, in_channel_stride]. padding: str. Padding method for layer. "SAME" or "VALID". has_bias: bool. If True, append 1 to in_channel. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._filter_shape = filter_shape self._strides = strides self._padding = padding self._has_bias = has_bias self._inputs = inputs + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvInputKroneckerFactor, self).__init__() @property @@ -655,26 +811,34 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return self._inputs.dtype + def _compute_new_cov(self, idx=0): if idx != 0: raise ValueError("ConvInputKroneckerFactor only supports idx = 0") # TODO(jamesmartens): factor this patches stuff out into a utility function - filter_height, filter_width, in_channels, _ = self._filter_shape - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], - padding=self._padding) + with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs): + filter_height, filter_width, in_channels, _ = self._filter_shape - flatten_size = (filter_height * filter_width * in_channels) - patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + patches = array_ops.extract_image_patches( + self._inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], + padding=self._padding) - if self._has_bias: - patches_flat = _append_homog(patches_flat) + flatten_size = (filter_height * filter_width * in_channels) + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - return _compute_cov(patches_flat) + if self._has_bias: + patches_flat = _append_homog(patches_flat) + + return _compute_cov(patches_flat) class ConvOutputKroneckerFactor(InverseProvidingFactor): @@ -688,15 +852,18 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads): + def __init__(self, outputs_grads, colocate_cov_ops_with_inputs=False): """Initializes ConvOutputKroneckerFactor. Args: outputs_grads: list of Tensors. Each Tensor is of shape - [batch_size, height, width, out_channels]. + [batch_size, height, width, out_channels]. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._out_channels = outputs_grads[0].shape.as_list()[3] self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvOutputKroneckerFactor, self).__init__() @property @@ -712,7 +879,286 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): - reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], - [-1, self._out_channels]) - return _compute_cov(reshaped_tensor) + with _maybe_colocate_with(self._outputs_grads[idx], + self._colocate_cov_ops_with_inputs): + reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], + [-1, self._out_channels]) + return _compute_cov(reshaped_tensor) + + +class FullyConnectedMultiKF(InverseProvidingFactor): + """Kronecker factor for a fully connected recurrent layer.""" + + def __init__(self, + tensor_lists, + has_bias=False, + colocate_cov_ops_with_inputs=False): + """Constructs a new `FullyConnectedMultiKF`. + + Args: + tensor_lists: List of lists of Tensors of shape [batch_size, n]. + has_bias: bool. If True, '1' is appended to each row. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. + """ + + self._orig_tensors_name = scope_string_from_params(tensor_lists) + self._batch_size = array_ops.shape(tensor_lists[0][0])[0] + self._num_timesteps = len(tensor_lists[0]) + + tensors = tuple( + array_ops.concat(tensor_list, 0) for tensor_list in tensor_lists) + if has_bias: + tensors = tuple(_append_homog(tensor) for tensor in tensors) + self._tensors = tensors + + self._cov_dt1 = None + self._option1quants_by_damping = {} + self._option2quants_by_damping = {} + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs + + super(FullyConnectedMultiKF, self).__init__() + + @property + def _var_scope(self): + return "ff_fc_multi/" + self._orig_tensors_name + + @property + def _num_sources(self): + return len(self._tensors) + + @property + def _dtype(self): + return self._tensors[0].dtype + + def make_covariance_update_op(self, ema_decay): + with _maybe_colocate_with(self._tensors, + self._colocate_cov_ops_with_inputs): + op = super(FullyConnectedMultiKF, + self).make_covariance_update_op(ema_decay) + + if self._cov_dt1 is not None: + new_cov_dt1 = math_ops.add_n( + tuple( + self._compute_new_cov_dt1(idx) + for idx in range(self._num_sources))) + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) + + return op + + def _compute_new_cov(self, idx=0): + tensor = self._tensors[idx] + normalizer = self._num_timesteps * self._batch_size + return _compute_cov(tensor, normalizer=normalizer) + + def _compute_new_cov_dt1(self, idx=0): + tensor = self._tensors[idx] + normalizer = self._num_timesteps * self._batch_size + tensor_present = tensor[:-self._batch_size, :] + tensor_future = tensor[self._batch_size:, :] + return _compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=normalizer) + + @property + def _cov_shape(self): + size = self._tensors[0].shape[1] + return [size, size] + + @property + def _vec_shape(self): + size = self._tensors[0].shape[1] + return [size] + + def get_option1quants(self, damping): + return self._option1quants_by_damping[damping] + + def get_option2quants(self, damping): + return self._option2quants_by_damping[damping] + + def get_cov_dt1(self): + assert self._cov_dt1 is not None + return self._cov_dt1 + + def register_cov_dt1(self): + """Create a variable representing temporal cross-covariance. + + (This is technically the second moment, not covariance, since it's + not mean subtracted.) + """ + if self._cov_dt1 is None: + with variable_scope.variable_scope(self._var_scope): + self._cov_dt1 = variable_scope.get_variable( + "cov_dt1", + initializer=init_ops.zeros_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + def register_option1quants(self, damping): + + self.register_eigendecomp() + self.register_cov_dt1() + + if damping not in self._option1quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Lmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + psi = variable_scope.get_variable( + "psi_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option1quants_by_damping[damping] = (Lmat, psi) + + def register_option2quants(self, damping): + + self.register_eigendecomp() + self.register_cov_dt1() + + if damping not in self._option2quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Pmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + Kmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Kmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + mu = variable_scope.get_variable( + "mu_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option2quants_by_damping[damping] = (Pmat, Kmat, mu) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + # TODO(b/69918258): Add correctness tests for this method. + # pylint: disable=invalid-name + + ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops() + + if (len(self._option1quants_by_damping) + + len(self._option2quants_by_damping)): + + # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from + # the pseudo-code in the original paper. Because the computations for + # the A and G case are essentially the same they can both be performed by + # the same class (this one). + + C1 = self.get_cov_dt1() + + # Get the eigendecomposition of C0 (= self.get_cov()) + eigen_e, eigen_V = self.get_eigendecomp() + + # TODO(b/69678661): Note, there is an implicit assumption here that C1 + # and C0 (as represented here by its eigen-decomp) are consistent. This + # could fail to be the case if self._cov and self._cov_dt1 are not updated + # consistently, or are somehow read between or during the cov updates. + # Can this possibly happen? Is there a way to prevent it? + + for damping, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # The following line imposses the symmetry assumed by "Option 1" on C1. + # Stangely the code can work okay with this line commented out, + # depending on how psd_eig is defined. I'm not sure why. + C1 = (C1 + array_ops.transpose(C1)) / 2.0 + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means \hat{Psi}) + hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0) + + # Compute the decomposition U*diag(psi)*U^T = hPsi + psi, U = utils.posdef_eig(hPsi) + + # L = C0^(-1/2) * U + Lmat = math_ops.matmul(invsqrtC0, U) + + ops.append(Lmat_var.assign(Lmat)) + ops.append(psi_var.assign(psi)) + + for damping, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + # compute C0^(-1/2) + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # Compute the product C0^(-1/2) * C1 + invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1) + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means \hat{Psi}) + hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0) + + # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi + # Note that we using the notation mu instead of "m" for the eigenvalues. + # Instead of computing the product hPsi^T * hPsi and then doing an + # eigen-decomposition of this we just compute the SVD of hPsi and then + # square the singular values to get the eigenvalues. For a justification + # of this approach, see: + # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition + sqrtmu, _, E = linalg_ops.svd(hPsi) + mu = math_ops.square(sqrtmu) + + # Mathematically, the eigenvalues should not should not exceed 1.0, but + # due to numerical issues, or possible issues with inconsistent + # values of C1 and (the eigen-decomposition of) C0 they might. So + # we enforce this condition. + mu = math_ops.minimum(mu, 1.0) + + # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) + Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) + + # K = C_0^(-1/2) * E + Kmat = math_ops.matmul(invsqrtC0, E) + + ops.append(Pmat_var.assign(Pmat)) + ops.append(Kmat_var.assign(Kmat)) + ops.append(mu_var.assign(mu)) + + return [control_flow_ops.group(*ops)] + + # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 2139a261e05e33bcb650f31d5d9e85f592009ba6..ca42afe6fb2f5c7d7de8b5b087dc11be30a75d5e 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,7 +26,9 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from functools import partial +import math import six from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb @@ -35,7 +37,6 @@ from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest # Names for various approximations that can be requested for Fisher blocks. @@ -58,12 +59,22 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = { APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, } +APPROX_KRONECKER_INDEP_NAME = "kron_indep" +APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" +APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" + +_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, + APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, + option=1), + APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, + option=2) +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" -# TODO(jamesmartens): need to add find_canonical_output back into this somewhere - def ensure_sequence(obj): """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" @@ -129,7 +140,10 @@ class LayerCollection(object): sum. """ - def __init__(self, graph=None, name="LayerCollection"): + def __init__(self, + graph=None, + colocate_cov_ops_with_inputs=False, + name="LayerCollection"): self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() self._linked_parameters = dict( @@ -140,6 +154,9 @@ class LayerCollection(object): self._default_generic_approximation = APPROX_FULL_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME + self._default_fully_connected_multi_approximation = ( + APPROX_KRONECKER_SERIES_2_NAME) + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @@ -149,19 +166,13 @@ class LayerCollection(object): """LossFunctions registered with this LayerCollection.""" return list(self._loss_dict.values()) - def is_variable_registered(self, variable): - """Checks whether the variable has already been registered. - - Args: - variable: A single variable or tensor. - Returns: - True if the variable has been registered either by itself or as part of a - tuple. - """ - return any([ - variable in key if isinstance(key, (tuple, list)) else variable == key - for key in self.fisher_blocks.keys() - ]) + @property + def registered_variables(self): + """A tuple of all of the variables currently registered.""" + tuple_of_tuples = (ensure_sequence(key) for key, block + in six.iteritems(self.fisher_blocks)) + flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) + return flat_tuple @property def linked_parameters(self): @@ -181,8 +192,7 @@ class LayerCollection(object): def default_generic_approximation(self): return self._default_generic_approximation - @default_generic_approximation.setter - def default_generic_approximation(self, value): + def set_default_generic_approximation(self, value): if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: raise ValueError( "{} is not a valid approximation for generic variables.".format( @@ -193,8 +203,7 @@ class LayerCollection(object): def default_fully_connected_approximation(self): return self._default_fully_connected_approximation - @default_fully_connected_approximation.setter - def default_fully_connected_approximation(self, value): + def set_default_fully_connected_approximation(self, value): if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: raise ValueError( "{} is not a valid approximation for fully connected layers.".format( @@ -205,50 +214,44 @@ class LayerCollection(object): def default_conv2d_approximation(self): return self._default_convolution_2d_approximation - @default_conv2d_approximation.setter - def default_conv2d_approximation(self, value): + def set_default_conv2d_approximation(self, value): if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError( "{} is not a valid approximation for 2d convolutional layers.".format( value)) self._default_convolution_2d_approximation = value + @property + def default_fully_connected_multi_approximation(self): + return self._default_fully_connected_multi_approximation + + def set_default_fully_connected_multi_approximation(self, value): + if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("{} is not a valid approximation for a fully-connected " + "multi layer.".format(value)) + self._default_fully_connected_multi_approximation = value + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. - Validation consists of checking whether the key was already registered or - if any of the elements of layer_key (if it's a tuple) were already - registered as part of another tuple (throws an error if so). If any of the - elements were registered by themselves, or as part of tuples that are - subsets of this layer_key, those registrations are first removed. - - If the layer_key is a subset of an existing registration, registration of - the new, smaller layer_key is skipped. - - e.g. If registrations include {'a': foo, ('b', 'c'): bar}, then - - register_layer('a', baz) -> ValueError - - register_layer(('b', 'c', 'd'), baz) -> - {'a': foo, ('b', 'c', 'd'): baz} - - register_layer('b', baz) -> - {'a': foo, ('b', 'c'): bar} (No change) - - register_layer(('a', 'd'), baz) -> - {('a', 'd'): baz, ('b', 'c'): bar} - - register_layer(('b', 'd'), baz) -> ValueError - Args: - layer_key: The key to check for in existing registrations and to register - if valid. - fisher_block: The associated fisher block. - reuse: Method to use for inserting new FisherBlocks. One of True, False, - or VARIABLE_SCOPE. + layer_key: A variable or tuple of variables. The key to check for in + existing registrations and to register if valid. + fisher_block: The associated `FisherBlock`. + reuse: Method to use for inserting new `FisherBlock`s. One of True, False, + or 'VARIABLE_SCOPE'. Raises: - ValueError: If the layer_key was already registered, or if a subset of the - layer_key has already been registered as part of a different tuple. + ValueError: If `layer_key` was already registered and reuse is `False`, + if `layer_key` was registered with a different block type, or if + `layer_key` shares any variables with but is not equal to a previously + registered key. + KeyError: If `reuse` is `True` but `layer_key` was not previously + registered. Returns: - FisherBlock registered under 'layer_key'. May or may not be the same as - 'fisher_block'. + The `FisherBlock` registered under `layer_key`. If `layer_key` was already + registered, this will be the previously registered `FisherBlock`. """ if reuse is VARIABLE_SCOPE: reuse = variable_scope.get_variable_scope().reuse @@ -268,110 +271,84 @@ class LayerCollection(object): # Insert fisher_block into self.fisher_blocks. if layer_key in self.fisher_blocks: raise ValueError("Duplicate registration: {}".format(layer_key)) - if isinstance(layer_key, (tuple, list)): - return self._register_block_with_sequence_key(layer_key, fisher_block) - else: - return self._register_block_with_nonsequence_key(layer_key, fisher_block) - - def _register_block_with_sequence_key(self, layer_key, fisher_block): - """Validates and registers the layer_key if it's a sequence.""" - # Find all keys that are either supersets or subsets of 'layer_key'. - inclusions = { - fisher_elt - for layer_elt in layer_key - for fisher_elt in self.fisher_blocks - if self._equal_or_subset(layer_elt, fisher_elt) + # Raise an error if any variable in layer_key has been registered in any + # other blocks. + variable_to_block = { + var: (params, block) + for (params, block) in self.fisher_blocks.items() + for var in ensure_sequence(params) } - - if not inclusions: - self.fisher_blocks[layer_key] = fisher_block - return fisher_block - - result_key = None - for key in inclusions: - fisher_block_key = key if isinstance(key, (tuple, list)) else (key,) - in_existing_only = set(fisher_block_key) - set(layer_key) - in_new_only = set(layer_key) - set(fisher_block_key) - - if in_existing_only and in_new_only: - # Existing and new key have an intersection but neither is a subset of - # the other. This is an error. + for variable in ensure_sequence(layer_key): + if variable in variable_to_block: + prev_key, prev_block = variable_to_block[variable] raise ValueError( - "Inconsistent registration, expected new key to be a subset or " - "superset of the existing key: existing is {}, new is {}".format( - key, layer_key)) - elif in_existing_only and not in_new_only: - # Existing key is strict superset of new key. Return existing - # FisherBlock. - logging.warning("Graph Registration Warning: tried to register " - "a subset ({}) of an already registered tuple " - "({}), skipping".format(layer_key, fisher_block_key)) - assert result_key is None - result_key = key - elif in_new_only and not in_existing_only: - # Existing key is a strict subset of new key. Replace existing - # FisherBlock with new one. - # - # TODO(b/68715045): This is dangerous. If there are existing - # registrations for a minibatch from elsewhere in the graph, they won't - # be re-registered with this new FisherBlock. The type of FisherBlock - # could also change here. - logging.warning( - "Replacing existing FisherBlock for key {} with new FisherBlock " - "for key {}. {} registered minibatches from the existing " - "FisherBlock will not be migrated.".format( - key, layer_key, - self.fisher_blocks[key].num_registered_minibatches)) - self.fisher_blocks.pop(key) - self.fisher_blocks[layer_key] = fisher_block - assert result_key is None - result_key = layer_key - elif not in_new_only and not in_existing_only: - # Existing and new are identical. Reuse the old FisherBlock. - # - # TODO(b/68715045): This is dangerous. If the new FisherBlock has - # existing registered minibatches, they will not be migrated to the - # existing FisherBlock. - assert result_key is None - result_key = key - else: - raise ValueError("Unexpected layer key conflict: {} vs. {}".format( - layer_key, key)) - - return self.fisher_blocks[result_key] - - def _register_block_with_nonsequence_key(self, layer_key, fisher_block): - """Validates and registers the layer_key if it's not a sequence.""" - inclusions = { - fisher_elt - for fisher_elt in self.fisher_blocks - if self._equal_or_subset(layer_key, fisher_elt) - } - - if not inclusions: - self.fisher_blocks[layer_key] = fisher_block - else: - logging.warning("Graph Registration Warning: tried to register " - "variable ({}) but a containing tuple was already " - "registered ({}), skipping".format(layer_key, inclusions)) - + "Attempted to register layer_key {} with block {}, but variable {}" + " was already registered in key {} with block {}.".format( + layer_key, fisher_block, variable, prev_key, prev_block)) + self.fisher_blocks[layer_key] = fisher_block return fisher_block - def _equal_or_subset(self, elt1, elt2): - """Checks if the elements are equal or one is contained in the other.""" - return (elt1 == elt2 or (isinstance(elt1, - (tuple, list)) and elt2 in elt1) or - (isinstance(elt2, (tuple, list)) and elt1 in elt2)) - def get_use_count_map(self): """Returns a dict of variables to their number of registrations.""" + # TODO(b/70283403): Reimplement this in the old way, where each + # registration function would be responsible for incrementing the count. + # Also, this version has a bug: it won't do the right thing for generic + # registration for parameters that are shared. i.e. it won't set the use + # count to infinity. vars_to_uses = defaultdict(int) for key, block in six.iteritems(self.fisher_blocks): - key = key if isinstance(key, (tuple, list)) else (key,) + n = ( + block.num_inputs()*block.num_registered_minibatches if isinstance( + block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB)) + else block.num_registered_minibatches) + key = ensure_sequence(key) for k in key: - vars_to_uses[k] += block.num_registered_minibatches + vars_to_uses[k] += n return vars_to_uses + def check_registration(self, variables): + """Checks that all variable uses have been registered properly. + + Args: + variables: List of variables. + + Raises: + ValueError: If any registered variables are not included in the list. + ValueError: If any variable in the list is not registered. + ValueError: If any variable in the list is registered with the wrong + number of "uses" in the subgraph recorded (vs the number of times that + variable is actually used in the subgraph). + """ + # Note that overlapping parameters (i.e. those that share variables) will + # be caught by layer_collection.LayerParametersDict during registration. + + reg_use_map = self.get_use_count_map() + + error_messages = [] + + for var in variables: + total_uses = self.subgraph.variable_uses(var) + reg_uses = reg_use_map[var] + + if reg_uses == 0: + error_messages.append("Variable {} not registered.".format(var)) + elif (not math.isinf(reg_uses)) and reg_uses != total_uses: + error_messages.append( + "Variable {} registered with wrong number of uses ({} " + "registrations vs {} uses).".format(var, reg_uses, total_uses)) + + num_get_vars = len(reg_use_map) + + if num_get_vars > len(variables): + error_messages.append("{} registered variables were not included in list." + .format(num_get_vars - len(variables))) + + if error_messages: + error_messages = [ + "Found the following errors with variable registration:" + ] + error_messages + raise ValueError("\n\t".join(error_messages)) + def get_blocks(self): return self.fisher_blocks.values() @@ -463,11 +440,11 @@ class LayerCollection(object): this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_size]. Preactivations + outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -509,10 +486,10 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. - Preactivations produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + Output produced by layer. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -542,14 +519,11 @@ class LayerCollection(object): """Registers a generic layer. Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [kernel_height, - kernel_width, in_channels, out_channels]. Bias should have shape - [out_channels]. + params: Tensor or tuple of Tensors corresponding to the parameters. batch_size: 0-D Tensor. Size of the minibatch. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "full" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -570,6 +544,47 @@ class LayerCollection(object): block = self.register_block(params, block_type(self, params), reuse=reuse) block.register_additional_minibatch(batch_size) + def register_fully_connected_multi(self, params, inputs, outputs, + approx=None): + """Register fully connected layers with shared parameters. + + This can handle general fully-connected layers with shared parameters, but + has specialized approximations to deal with the case where there is a + meaningful linear order to the share instances (such as in an RNN). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs + to layer. In the case of RNNs, one Tensor per time step. + outputs: A list of tensors, the same length as 'inputs', each of shape + [batch_size, output_size]. Outputs produced by layer. In the case of + RNNs, one Tensor per time step. + approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". + + Raises: + ValueError: For improper value to 'approx'. + """ + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_multi_approximation + has_bias = isinstance(params, (tuple, list)) + + # TODO(b/70283649): something along the lines of find_canonical_output + # should be added back in here (and for the other block types, arguably). + + if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("Bad value {} for approx.".format(approx)) + block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] + + # For now we don't support multiple minibatches for this type of layer, so + # we set reuse=False + self.register_block(params, + block_type(self, inputs, outputs, has_bias=has_bias), + reuse=False) + def register_categorical_predictive_distribution(self, logits, seed=None, @@ -710,6 +725,9 @@ class LayerCollection(object): "LayerCollection.fisher_factors. The pair cannot be hashed.").format( cls, args)) - with variable_scope.variable_scope(self._var_scope): - return utils.setdefault(self.fisher_factors, (cls, args), - lambda: cls(*args)) + key = cls, args + if key not in self.fisher_factors: + colo = self._colocate_cov_ops_with_inputs + with variable_scope.variable_scope(self._var_scope): + self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo) + return self.fisher_factors[key] diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index e2e5bc3ffea3e52087c24802948bc8260e3b199a..d449abcfa78b361b9d4774ca5c2e936f14f65433 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -91,13 +91,13 @@ class LossFunction(object): @abc.abstractmethod def _evaluate(self, targets): - """Evaluates the log probability of the targets. + """Evaluates the negative log probability of the targets. Args: targets: Tensor that distribution can calculate log_prob() of. Returns: - log probability of each target, summed across all targets. + negative log probability of each target, summed across all targets. """ pass diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 88299e495cb3069280cd3ae33d1cdd65f653a01b..ecf7f3e4e5ab7d9c151f760fdab733bc3830e37b 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -35,17 +35,20 @@ from tensorflow.python.training import gradient_descent class KfacOptimizer(gradient_descent.GradientDescentOptimizer): """The KFAC Optimizer (https://arxiv.org/abs/1503.05671).""" - def __init__( - self, - learning_rate, - cov_ema_decay, - damping, - layer_collection, - momentum=0., - momentum_type="regular", - norm_constraint=None, - name="KFAC", - estimation_mode="gradients"): + def __init__(self, + learning_rate, + cov_ema_decay, + damping, + layer_collection, + var_list=None, + momentum=0., + momentum_type="regular", + norm_constraint=None, + name="KFAC", + estimation_mode="gradients", + colocate_gradients_with_ops=False, + cov_devices=None, + inv_devices=None): """Initializes the KFAC optimizer with the given settings. Args: @@ -64,6 +67,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): blocks, kronecker factors, and losses associated with the graph. The layer_collection cannot be modified after KfacOptimizer's initialization. + var_list: Optional list or tuple of variables to train. Defaults to the + list of variables collected in the graph under the key + `GraphKeys.TRAINABLE_VARIABLES`. momentum: The momentum value for this optimizer. Only applies when momentum_type is 'regular' or 'adam'. (Default: 0) momentum_type: The type of momentum to use in this optimizer, one of @@ -77,6 +83,14 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): 'gradients', 'empirical', 'curvature_propagation', or 'exact'. (Default: 'gradients'). See the doc-string for FisherEstimator for more a more detailed description of these options. + colocate_gradients_with_ops: Whether we should request gradients we + compute in the estimator be colocated with their respective ops. + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. Raises: ValueError: If the momentum type is unsupported. @@ -86,13 +100,19 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): or 'adam'. """ - # We may consider determining the set of variables some other way, but for - # now it's just all the trainable variables. - variables = tf_variables.trainable_variables() + variables = var_list + if variables is None: + variables = tf_variables.trainable_variables() - self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping, - layer_collection, - estimation_mode=estimation_mode) + self._fisher_est = est.FisherEstimator( + variables, + cov_ema_decay, + damping, + layer_collection, + estimation_mode=estimation_mode, + colocate_gradients_with_ops=colocate_gradients_with_ops, + cov_devices=cov_devices, + inv_devices=inv_devices) momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel"] @@ -107,7 +127,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): raise ValueError("Momentum must be unspecified if using a momentum_type " "other than 'regular' or 'adam'.") - self._momentum = ops.convert_to_tensor(momentum, name="momentum") + self._momentum = momentum self._momentum_type = momentum_type self._norm_constraint = norm_constraint @@ -131,16 +151,24 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): return self._fisher_est.damping def minimize(self, *args, **kwargs): - - if "var_list" not in kwargs: - kwargs["var_list"] = tf_variables.trainable_variables() - + kwargs["var_list"] = kwargs.get("var_list") or self.variables if set(kwargs["var_list"]) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") - return super(KfacOptimizer, self).minimize(*args, **kwargs) + def compute_gradients(self, *args, **kwargs): + # args[1] could be our var_list + if len(args) > 1: + var_list = args[1] + else: + kwargs["var_list"] = kwargs.get("var_list") or self.variables + var_list = kwargs["var_list"] + if set(var_list) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + return super(KfacOptimizer, self).compute_gradients(*args, **kwargs) + def apply_gradients(self, grads_and_vars, *args, **kwargs): """Applies gradients to variables. @@ -297,14 +325,17 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._batch_size, dtype=fft_precon_grads[0].dtype) # compute the entries of the 2x2 matrix - m_11 = (_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size - + self.damping * _inner_product_list(precon_grads, precon_grads)) + m_11 = ( + _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(precon_grads, precon_grads)) - m_21 = (_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size - + self.damping * _inner_product_list(prev_updates, precon_grads)) + m_21 = ( + _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(prev_updates, precon_grads)) - m_22 = (_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size - + self.damping * _inner_product_list(prev_updates, prev_updates)) + m_22 = ( + _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size + + self.damping * _inner_product_list(prev_updates, prev_updates)) def non_zero_prevupd_case(): r"""Computes optimal (alpha, mu) given non-zero previous update. @@ -390,8 +421,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): grads = list(grad for (grad, _) in grads_and_vars) variables = list(var for (_, var) in grads_and_vars) # previous updates are the negative velocities (up to scaling by LR) - prev_updates = list(-self._zeros_slot(var, "velocity", self._name) - for var in variables) + prev_updates = list( + -self._zeros_slot(var, "velocity", self._name) for var in variables) # Compute optimal velocity update parameters according to quadratic model alpha, mu, _ = self._compute_qmodel_hyperparams( diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index 0fd7f5147739f0f46d2ab6a1c284c6dc75f53cc2..cec018e406bc51c07f5cafcc2c38efe7e9601618 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -28,9 +28,9 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops - # Method used for inverting matrices. POSDEF_INV_METHOD = "cholesky" +POSDEF_EIG_METHOD = "self_adjoint" def set_global_constants(posdef_inv_method=None): @@ -64,13 +64,6 @@ class SequenceDict(object): return list(self._dict.items()) -def setdefault(dct, key, thunk): - """Like dict.setdefault but delays evaluation of the value to be set.""" - if key not in dct: - dct[key] = thunk() - return dct[key] - - def tensors_to_column(tensors): """Converts a tensor or list of tensors to a column vector. @@ -169,33 +162,11 @@ def mat2d_to_layer_params(vector_template, mat2d): return array_ops.reshape(mat2d, vector_template.shape) -def compute_pi(left_factor, right_factor): - """Computes the scalar constant pi for Tikhonov regularization/damping. - - pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) - See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. - - Args: - left_factor: The left Kronecker factor Tensor. - right_factor: The right Kronecker factor Tensor. - - Returns: - The computed scalar constant pi for these Kronecker Factors (as a Tensor). - """ - # Instead of dividing by the dim of the norm, we multiply by the dim of the - # other norm. This works out the same in the ratio. - left_norm = math_ops.trace(left_factor) * right_factor.get_shape().as_list()[ - 0] - right_norm = math_ops.trace(right_factor) * left_factor.get_shape().as_list()[ - 0] - return math_ops.sqrt(left_norm / right_norm) - - def posdef_inv(tensor, damping): """Computes the inverse of tensor + damping * identity.""" identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) damping = math_ops.cast(damping, dtype=tensor.dtype) - return posdef_inv_funcs[POSDEF_INV_METHOD](tensor, identity, damping) + return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) def posdef_inv_matrix_inverse(tensor, identity, damping): @@ -209,9 +180,44 @@ def posdef_inv_cholesky(tensor, identity, damping): return linalg_ops.cholesky_solve(chol, identity) -posdef_inv_funcs = { +def posdef_inv_eig(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) with eigendecomposition.""" + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig( + tensor + damping * identity) + return math_ops.matmul( + eigenvectors / eigenvalues, eigenvectors, transpose_b=True) + + +posdef_inv_functions = { "matrix_inverse": posdef_inv_matrix_inverse, "cholesky": posdef_inv_cholesky, + "eig": posdef_inv_eig, +} + + +def posdef_eig(mat): + """Computes the eigendecomposition of a positive semidefinite matrix.""" + return posdef_eig_functions[POSDEF_EIG_METHOD](mat) + + +def posdef_eig_svd(mat): + """Computes the singular values and left singular vectors of a matrix.""" + evals, evecs, _ = linalg_ops.svd(mat) + + return evals, evecs + + +def posdef_eig_self_adjoint(mat): + """Computes eigendecomposition using self_adjoint_eig.""" + evals, evecs = linalg_ops.self_adjoint_eig(mat) + evals = math_ops.abs(evals) # Should be equivalent to svd approach. + + return evals, evecs + + +posdef_eig_functions = { + "self_adjoint": posdef_eig_self_adjoint, + "svd": posdef_eig_svd, } @@ -268,8 +274,8 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): # generated by the first gradients_impl.gradients call. us = [array_ops.zeros_like(y) + float("nan") for y in ys] - dydxs = gradients_impl.gradients(ys, xs, grad_ys=us, - stop_gradients=stop_gradients) + dydxs = gradients_impl.gradients( + ys, xs, grad_ys=us, stop_gradients=stop_gradients) # Deal with strange types that gradients_impl.gradients returns but can't # deal with. @@ -285,3 +291,6 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) return dysdx + +# TODO(b/69623235): Add a function for finding tensors that share gradients +# to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index ddbb4485ce6967082f1844c6d798c078f1cc303b..8903c90fbce6a890aa419d89b3b79d75f69509fc 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -25,13 +25,11 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "SequenceDict", - "setdefault", "tensors_to_column", "column_to_tensors", "kronecker_product", "layer_params_to_mat2d", "mat2d_to_layer_params", - "compute_pi", "posdef_inv", "posdef_inv_matrix_inverse", "posdef_inv_cholesky", diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 2f1f283811b6cb9e8bfb52ab2052afac1de700cb..852d06e1e3cc8f8deecd15b7436cd4e4a393ad66 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -61,6 +61,7 @@ tf_custom_op_py_library( "python/layers/normalization.py", "python/layers/optimizers.py", "python/layers/regularizers.py", + "python/layers/rev_block_lib.py", "python/layers/summaries.py", "python/layers/target_column.py", "python/layers/utils.py", @@ -376,6 +377,20 @@ py_test( ], ) +py_test( + name = "rev_block_lib_test", + size = "small", + srcs = ["python/layers/rev_block_lib_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:init_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index d309ba958ded86afdc1e4bba2ff471a5181cda4e..6c624929f20503054e0258aad8a843f4a201be64 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -42,6 +42,9 @@ See the @{$python/contrib.layers} guide. @@relu @@relu6 @@repeat +@@recompute_grad +@@RevBlock +@@rev_block @@safe_embedding_lookup_sparse @@scale_gradient @@separable_conv2d diff --git a/tensorflow/contrib/layers/python/layers/__init__.py b/tensorflow/contrib/layers/python/layers/__init__.py index 03337f9a5d11784316124442125bb498c4ce9603..f1ae2de68be33880a6fc09957f4d857973902b26 100644 --- a/tensorflow/contrib/layers/python/layers/__init__.py +++ b/tensorflow/contrib/layers/python/layers/__init__.py @@ -28,6 +28,7 @@ from tensorflow.contrib.layers.python.layers.layers import * from tensorflow.contrib.layers.python.layers.normalization import * from tensorflow.contrib.layers.python.layers.optimizers import * from tensorflow.contrib.layers.python.layers.regularizers import * +from tensorflow.contrib.layers.python.layers.rev_block_lib import * from tensorflow.contrib.layers.python.layers.summaries import * from tensorflow.contrib.layers.python.layers.target_column import * from tensorflow.contrib.layers.python.ops.bucketization_op import * diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 226d933d85d91600e36ffb84212703e10455bfbb..8d2931b4867938024a494459c77976e1e714de5a 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -156,6 +156,10 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +# Imports the core `InputLayer` symbol in contrib during development. +InputLayer = fc_core.InputLayer # pylint: disable=invalid-name + + class _LinearEmbeddingLookupArguments( collections.namedtuple("_LinearEmbeddingLookupArguments", ["input_tensor", @@ -521,7 +525,7 @@ def sparse_column_with_integerized_feature(column_name, Args: column_name: A string defining sparse column name. - bucket_size: An int that is > 1. The number of buckets. It should be bigger + bucket_size: An int that is >= 1. The number of buckets. It should be bigger than maximum feature. In other words features in this column should be an int64 in range [0, bucket_size) combiner: A string specifying how to reduce if the sparse column is @@ -539,7 +543,7 @@ def sparse_column_with_integerized_feature(column_name, An integerized _SparseColumn definition. Raises: - ValueError: bucket_size is not greater than 1. + ValueError: bucket_size is less than 1. ValueError: dtype is not integer. """ return _SparseColumnIntegerized( diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index fa0047f05d893f6543ddb1680824a32469e13293..78affea44cbfb92523063968dbc1be98841854db 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -97,10 +97,13 @@ def _input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank, - default_name): + default_name, + cols_to_outs=None): """Implementation of `input_from(_sequence)_feature_columns`.""" columns_to_tensors = columns_to_tensors.copy() check_feature_columns(feature_columns) + if cols_to_outs is not None and not isinstance(cols_to_outs, dict): + raise ValueError('cols_to_outs must be a dict unless None') with variable_scope.variable_scope(scope, default_name=default_name, values=columns_to_tensors.values()): @@ -144,6 +147,8 @@ def _input_from_feature_columns(columns_to_tensors, except ValueError as e: raise ValueError('Error creating input layer for column: {}.\n' '{}, {}'.format(column.name, e, ee)) + if cols_to_outs is not None: + cols_to_outs[column] = output_tensors[-1] return array_ops.concat(output_tensors, output_rank - 1) @@ -151,7 +156,8 @@ def input_from_feature_columns(columns_to_tensors, feature_columns, weight_collections=None, trainable=True, - scope=None): + scope=None, + cols_to_outs=None): """A tf.contrib.layers style input layer builder based on FeatureColumns. Generally a single example in training data is described with feature columns. @@ -196,6 +202,8 @@ def input_from_feature_columns(columns_to_tensors, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for variable_scope. + cols_to_outs: Optional dict from feature column to output tensor, + which is concatenated into the returned tensor. Returns: A Tensor which can be consumed by hidden layers in the neural network. @@ -209,7 +217,8 @@ def input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank=2, - default_name='input_from_feature_columns') + default_name='input_from_feature_columns', + cols_to_outs=cols_to_outs) @experimental diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index fbfa0e32de55edab3c90189ddfe05ab826ac9167..e6bbd86ab722c4e853a59f816bed8a8ac1fe9ede 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -607,6 +607,31 @@ class CreateInputLayersForDNNsTest(test.TestCase): # Verify cross compatibility: Core builder output should equal to contrib. self.assertAllEqual(output.eval().shape, output_core.eval().shape) + def testAllDNNColumnsWithColumnwiseOutputs(self): + sparse_column = feature_column.sparse_column_with_keys( + "ids", ["a", "b", "c", "unseen"]) + real_valued_column = feature_column.real_valued_column("income", 2) + one_hot_column = feature_column.one_hot_column(sparse_column) + embedding_column = feature_column.embedding_column(sparse_column, 10) + features = { + "ids": + sparse_tensor.SparseTensor( + values=["c", "b", "a"], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 1]), + "income": + constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]), + } + columns = [one_hot_column, embedding_column, real_valued_column] + cols_to_outs = {} + feature_column_ops.input_from_feature_columns( + features, columns, cols_to_outs=cols_to_outs) + with self.test_session(): + variables_lib.global_variables_initializer().run() + lookup_ops.tables_initializer().run() + for column in columns: + self.assertTrue(column in cols_to_outs) + def testRealValuedColumn(self): real_valued = feature_column.real_valued_column("price") features = {"price": constant_op.constant([[20.], [110], [-3]])} diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index b12a882d9ae88f7cf4f920cfa5872e5de1c67290..51610f21b24f1d40f26630cc1e69ca723d130639 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -79,7 +79,8 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, ``` * To get [Delving Deep into Rectifiers]( - http://arxiv.org/pdf/1502.01852v1.pdf), use (Default):
+ http://arxiv.org/pdf/1502.01852v1.pdf) (also know as the "MSRA + initialization"), use (Default):
`factor=2.0 mode='FAN_IN' uniform=False` * To get [Convolutional Architecture for Fast Feature Embedding]( http://arxiv.org/abs/1408.5093), use:
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 30630852181e8f4fdf6f8dd83fb852759806b36b..0d25a09852544a7eb1ed5eb9c2f3402d9064d91a 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -309,7 +309,6 @@ def _fused_batch_norm(inputs, new_shape = [-1, channels, 1, 1] inputs = array_ops.reshape(inputs, new_shape) inputs_shape = inputs.get_shape() - dtype = inputs.dtype.base_dtype if data_format == DATA_FORMAT_NHWC: params_shape = inputs_shape[-1:] else: @@ -2562,7 +2561,10 @@ def separable_convolution2d( regularizer=weights_regularizer, trainable=trainable, collections=weights_collections) - strides = [1, 1, stride_h, stride_w] if data_format.startswith('NC') else [1, stride_h, stride_w, 1] + strides = [1, 1, stride_h, + stride_w] if data_format.startswith('NC') else [ + 1, stride_h, stride_w, 1 + ] outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding, rate=utils.two_element_tuple(rate), @@ -2652,51 +2654,52 @@ def spatial_softmax(features, ValueError: If unexpected data_format specified. ValueError: If num_channels dimension is unspecified. """ - shape = array_ops.shape(features) - static_shape = features.shape - if data_format == DATA_FORMAT_NHWC: - height, width, num_channels = shape[1], shape[2], static_shape[3] - elif data_format == DATA_FORMAT_NCHW: - num_channels, height, width = static_shape[1], shape[2], shape[3] - else: - raise ValueError('data_format has to be either NCHW or NHWC.') - if num_channels.value is None: - raise ValueError('The num_channels dimension of the inputs to ' - '`spatial_softmax` should be defined. Found `None`.') - - with ops.name_scope(name, 'spatial_softmax', [features]) as name: - # Create tensors for x and y coordinate values, scaled to range [-1, 1]. - pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), - math_ops.lin_space(-1., 1., num=width), - indexing='ij') - pos_x = array_ops.reshape(pos_x, [height * width]) - pos_y = array_ops.reshape(pos_y, [height * width]) - if temperature is None: - temperature_collections = utils.get_variable_collections( - variables_collections, 'temperature') - temperature = variables.model_variable( - 'temperature', - shape=(), - dtype=dtypes.float32, - initializer=init_ops.ones_initializer(), - collections=temperature_collections, - trainable=trainable) - if data_format == 'NCHW': - features = array_ops.reshape(features, [-1, height * width]) + with variable_scope.variable_scope(name, 'spatial_softmax'): + shape = array_ops.shape(features) + static_shape = features.shape + if data_format == DATA_FORMAT_NHWC: + height, width, num_channels = shape[1], shape[2], static_shape[3] + elif data_format == DATA_FORMAT_NCHW: + num_channels, height, width = static_shape[1], shape[2], shape[3] else: - features = array_ops.reshape( - array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) - - softmax_attention = nn.softmax(features/temperature) - expected_x = math_ops.reduce_sum( - pos_x * softmax_attention, [1], keep_dims=True) - expected_y = math_ops.reduce_sum( - pos_y * softmax_attention, [1], keep_dims=True) - expected_xy = array_ops.concat([expected_x, expected_y], 1) - feature_keypoints = array_ops.reshape( - expected_xy, [-1, num_channels.value * 2]) - feature_keypoints.set_shape([None, num_channels.value * 2]) - return feature_keypoints + raise ValueError('data_format has to be either NCHW or NHWC.') + if num_channels.value is None: + raise ValueError('The num_channels dimension of the inputs to ' + '`spatial_softmax` should be defined. Found `None`.') + + with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): + # Create tensors for x and y coordinate values, scaled to range [-1, 1]. + pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), + math_ops.lin_space(-1., 1., num=width), + indexing='ij') + pos_x = array_ops.reshape(pos_x, [height * width]) + pos_y = array_ops.reshape(pos_y, [height * width]) + if temperature is None: + temperature_collections = utils.get_variable_collections( + variables_collections, 'temperature') + temperature = variables.model_variable( + 'temperature', + shape=(), + dtype=dtypes.float32, + initializer=init_ops.ones_initializer(), + collections=temperature_collections, + trainable=trainable) + if data_format == 'NCHW': + features = array_ops.reshape(features, [-1, height * width]) + else: + features = array_ops.reshape( + array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) + + softmax_attention = nn.softmax(features/temperature) + expected_x = math_ops.reduce_sum( + pos_x * softmax_attention, [1], keep_dims=True) + expected_y = math_ops.reduce_sum( + pos_y * softmax_attention, [1], keep_dims=True) + expected_xy = array_ops.concat([expected_x, expected_y], 1) + feature_keypoints = array_ops.reshape( + expected_xy, [-1, num_channels.value * 2]) + feature_keypoints.set_shape([None, num_channels.value * 2]) + return feature_keypoints def stack(inputs, layer, stack_args, **kwargs): diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 9019d3a60991fa0274de10c95986a61c21223bd7..ae64b75d939ce0ffab300b01d3cfcb67a9d0da1c 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1779,7 +1779,8 @@ class BatchNormTest(test.TestCase): dtype = dtypes.float32 height, width = 3, 3 with self.test_session(): - images = np.random.uniform(size=(5, height, width, 3)).astype(dtype.as_numpy_dtype) + images = np.random.uniform(size=(5, height, width, 3)).astype( + dtype.as_numpy_dtype) output = _layers.batch_norm(images, fused=fused) expected_name = ('BatchNorm/FusedBatchNorm' if fused else 'BatchNorm/batchnorm') @@ -2665,18 +2666,18 @@ class BatchNormTest(test.TestCase): # Test case for 11673 with self.test_session() as sess: a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10)) - b_32 = _layers.batch_norm(a_32, center=False, data_format='NCHW', - zero_debias_moving_mean=True) + _layers.batch_norm( + a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True) a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10)) - b_16 = _layers.batch_norm(a_16, center=False, data_format='NCHW', - zero_debias_moving_mean=True) + _layers.batch_norm( + a_16, center=False, data_format='NCHW', zero_debias_moving_mean=True) sess.run(variables_lib.global_variables_initializer()) def testVariablesAreFloat32(self): height, width = 3, 3 with self.test_session(): - images = random_ops.random_uniform((5, height, width, 3), - seed=1, dtype=dtypes.float16) + images = random_ops.random_uniform( + (5, height, width, 3), seed=1, dtype=dtypes.float16) _layers.batch_norm(images, scale=True) beta = variables.get_variables_by_name('beta')[0] gamma = variables.get_variables_by_name('gamma')[0] @@ -2691,17 +2692,13 @@ class BatchNormTest(test.TestCase): channels = shape[1] images = np.arange(np.product(shape), dtype=dtype).reshape(shape) beta = init_ops.constant_initializer( - np.arange( - 2, channels + 2, dtype=np.float32)) + np.arange(2, channels + 2, dtype=np.float32)) gamma = init_ops.constant_initializer( - np.arange( - 10, channels + 10, dtype=np.float32) * 2.0) + np.arange(10, channels + 10, dtype=np.float32) * 2.0) mean = init_ops.constant_initializer( - np.arange( - 3, channels + 3, dtype=np.float32) * 5.0) + np.arange(3, channels + 3, dtype=np.float32) * 5.0) variance = init_ops.constant_initializer( - np.arange( - 1, channels + 1, dtype=np.float32) * 4.0) + np.arange(1, channels + 1, dtype=np.float32) * 4.0) output = _layers.batch_norm( images, fused=True, @@ -2726,7 +2723,6 @@ class BatchNormTest(test.TestCase): res_16 = self._runFusedBatchNorm(shape, np.float16) self.assertAllClose(res_32, res_16, rtol=1e-3) - def testAdjustmentCreated(self): # Tests that the adjustment is appropriately passed to and used by the core # BN layer. @@ -3336,11 +3332,18 @@ class SeparableConv2dTest(test.TestCase): batch, height, width = 4, 10, 12 kernel_dim, stride = 3, 2 images = random_ops.random_uniform((batch, 3, height, width), seed=1) - output = layers_lib.separable_conv2d(images, num_outputs=num_filters, kernel_size=[kernel_dim, kernel_dim], - depth_multiplier=2, stride=stride, padding='VALID', data_format='NCHW') - self.assertListEqual( - output.get_shape().as_list(), [batch, correct_output_filters, - (height - kernel_dim + 1) // stride, (width - kernel_dim + 1) // stride]) + output = layers_lib.separable_conv2d( + images, + num_outputs=num_filters, + kernel_size=[kernel_dim, kernel_dim], + depth_multiplier=2, + stride=stride, + padding='VALID', + data_format='NCHW') + self.assertListEqual(output.get_shape().as_list(), [ + batch, correct_output_filters, (height - kernel_dim + 1) // stride, + (width - kernel_dim + 1) // stride + ]) class ScaleGradientTests(test.TestCase): diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..123275e1fde047cd3772528641b2e3b09742fbdc --- /dev/null +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -0,0 +1,583 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible Residual Block. + +From +[The Reversible Residual Network: Backpropagation Without Storing +Activations](https://arxiv.org/abs/1707.04585). + +Also contains the @recompute_grad decorator, which recomputes the forward +function on the backwards pass. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import re + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.contrib.framework.python import ops as contrib_framework_ops +from tensorflow.python.framework import function +from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.layers import base +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + +__all__ = ["rev_block", "RevBlock", "recompute_grad"] + +LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") + + +def _acc_grads(*lists_of_grads): + """Accumulates lists of gradients.""" + acc_grads = [] + for grads in zip(*lists_of_grads): + grads = [g for g in grads if g is not None] + if grads: + acc_grads.append(math_ops.add_n(grads)) + else: + acc_grads.append(None) + return acc_grads + + +def _rev_layer_forward(xs, f, g, f_side_input, g_side_input, + gate_outputs=False): + """Forward for 1 reversible layer.""" + x1, x2 = xs + y1 = x1 + (f(x2, f_side_input) if f_side_input else f(x2)) + y2 = x2 + (g(y1, g_side_input) if g_side_input else g(y1)) + if gate_outputs: + return control_flow_ops.tuple([y1, y2]) + else: + return (y1, y2) + + +def _rev_layer_backward(ys, grad_ys, f, g, f_vars, f_side_input, g_vars, + g_side_input): + """Backprop for 1 layer.""" + y1, y2 = ys + grad_y1, grad_y2 = grad_ys + + # Reconstruct intermediates and inputs (x1, x2) + # stop_gradients required on fn inputs to prevent infinite recursion into this + # grad function on the calls to gradients. + y1_stop = array_ops.stop_gradient(y1) + g_side_input = [array_ops.stop_gradient(t) for t in g_side_input] + gy1 = g(y1_stop, g_side_input) if g_side_input else g(y1_stop) + + x2 = y2 - gy1 + x2_stop = array_ops.stop_gradient(x2) + f_side_input = [array_ops.stop_gradient(t) for t in f_side_input] + fx2 = f(x2_stop, f_side_input) if f_side_input else f(x2_stop) + + x1 = y1 - fx2 + + # Compute gradients wrt to inputs + # dL/dy2 * dG(y1)/y1 + grad_gy1_y2 = gradients_impl.gradients(gy1, y1_stop, grad_y2)[0] + grad_x1 = grad_y1 + grad_gy1_y2 + grad_x2 = ( + gradients_impl.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 + + gradients_impl.gradients(fx2, x2_stop, grad_gy1_y2)[0]) + + # Compute gradients wrt to vars and side inputs in f and g + grads1 = gradients_impl.gradients(gy1, g_vars + g_side_input, grad_y2) + grad_g_vars, grad_g_side = grads1[:len(g_vars)], grads1[len(g_vars):] + grads2 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_y1) + grad_f_y1, grad_f_side1 = grads2[:len(f_vars)], grads2[len(f_vars):] + grads3 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_gy1_y2) + grad_f_y2, grad_f_side2 = grads3[:len(f_vars)], grads3[len(f_vars):] + grad_f_vars = _acc_grads(grad_f_y1, grad_f_y2) + + grad_f_side = _acc_grads(grad_f_side1, grad_f_side2) + + # Put returns in a tuple to ensure a constant memory budget (i.e. don't want + # the subsequent layer to start computing and consuming memory based on a + # subset of these values). + outputs = ((x1, x2), (grad_x1, grad_x2), (grad_f_vars, grad_f_side), + (grad_g_vars, grad_g_side)) + tupled = control_flow_ops.tuple(nest.flatten(outputs)) + return nest.pack_sequence_as(outputs, tupled) + + +def _rev_block_forward(x1, + x2, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + gate_outputs=False): + """Forward for a series of reversible layers.""" + out = (x1, x2) + for i in xrange(num_layers): + out = _rev_layer_forward( + out, f[i], g[i], f_side_input, g_side_input, gate_outputs=gate_outputs) + + y1, y2 = out + return y1, y2 + + +def _scope_wrap(fn, scope): + + @functools.wraps(fn) + def wrap(*args, **kwargs): + with variable_scope.variable_scope(scope): + return fn(*args, **kwargs) + + return wrap + + +class RevBlock(base.Layer): + """Block of reversible layers. See rev_block.""" + + def __init__(self, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + use_efficient_backprop=True, + name="revblock", + **kwargs): + super(RevBlock, self).__init__(name=name, **kwargs) + + if isinstance(f, list): + assert len(f) == num_layers + else: + f = [f] * num_layers + + if isinstance(g, list): + assert len(g) == num_layers + else: + g = [g] * num_layers + + f = [_scope_wrap(fn, "revlayer_%d/f" % i) for i, fn in enumerate(f)] + g = [_scope_wrap(fn, "revlayer_%d/g" % i) for i, fn in enumerate(g)] + + self.f = f + self.g = g + + self.num_layers = num_layers + self.f_side_input = f_side_input or [] + self.g_side_input = g_side_input or [] + + self._use_efficient_backprop = use_efficient_backprop + + def call(self, inputs, forward=True): + vs = variable_scope.get_variable_scope() + vars_before = vs.global_variables() + + if forward: + x1, x2 = inputs + out = self._forward(x1, x2) + else: + y1, y2 = inputs + out = self._backward(y1, y2) + + # Add any created variables to the Layer's variable stores + new_vars = vs.global_variables()[len(vars_before):] + train_vars = vs.trainable_variables() + for new_var in new_vars: + if new_var in train_vars: + self._trainable_weights.append(new_var) + else: + self._non_trainable_weights.append(new_var) + + return out + + def forward(self, x1, x2): + return self.apply([x1, x2]) + + def backward(self, y1, y2): + return self.apply([y1, y2], forward=False) + + def build(self, _): + logging.warn("RevBlock constructs its variables on first call, not on " + "build.") + self.built = True + + def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): + """Custom gradient fn for a block of reversible residual layers.""" + side_inputs = inputs[2:] + f_side_idxs = [None] * len(self.f_side_input) + g_side_idxs = [None] * len(self.g_side_input) + assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) + + for i, t in enumerate(side_inputs): + if t in self.f_side_input: + f_side_idxs[self.f_side_input.index(t)] = i + elif t in self.g_side_input: + g_side_idxs[self.g_side_input.index(t)] = i + else: + assert False + + f_vars = [[] for _ in range(self.num_layers)] + g_vars = [[] for _ in range(self.num_layers)] + f_vars_idxs = [[] for _ in range(self.num_layers)] + g_vars_idxs = [[] for _ in range(self.num_layers)] + + for i, t in enumerate(variables): + ref = _underlying_variable_ref(t) + + # Use the name to identify the layer number and function (f or g) + regex = LAYER_RE.match(ref.name) + layer_no = int(regex.group(1)) + fn_name = regex.group(2) + if fn_name == "f": + f_vars[layer_no].append(ref) + f_vars_idxs[layer_no].append(i) + else: + assert fn_name == "g" + g_vars[layer_no].append(ref) + g_vars_idxs[layer_no].append(i) + + f_var_grads = [] + g_var_grads = [] + f_side_grads = [] + g_side_grads = [] + + # Reverse variable containers to go backward + f_vars.reverse() + g_vars.reverse() + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() + + with variable_scope.variable_scope(self.scope_name, reuse=True): + for i in xrange(self.num_layers): + ys, grad_ys, f_ret, g_ret = _rev_layer_backward( + ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], + self.g_side_input) + + grad_f_vars, grad_f_side = f_ret + grad_g_vars, grad_g_side = g_ret + f_var_grads.append(grad_f_vars) + g_var_grads.append(grad_g_vars) + f_side_grads.append(grad_f_side) + g_side_grads.append(grad_g_side) + + # Accumulate layer gradients for f_side_input and g_side_input + acc_f_side_grads = _acc_grads(*f_side_grads) + acc_g_side_grads = _acc_grads(*g_side_grads) + + # Use the stored idxs to put gradients in the passed-in order. + side_input_grads = [None] * len(side_inputs) + variable_grads = [None] * len(variables) + + # Variable gradients were collected in reverse layer order. Reverse to match + # idxs. + f_var_grads.reverse() + g_var_grads.reverse() + for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( + zip(g_vars_idxs, g_var_grads)): + for i, grad in zip(idxs, grads): + variable_grads[i] = grad + + for i, grad in zip(f_side_idxs, acc_f_side_grads): + side_input_grads[i] = grad + for i, grad in zip(g_side_idxs, acc_g_side_grads): + side_input_grads[i] = grad + + grad_x1, grad_x2 = grad_ys + return [grad_x1, grad_x2] + side_input_grads, variable_grads + + def _forward(self, x1, x2): + """Run forward through the reversible layers.""" + + side_inputs = [self.f_side_input, self.g_side_input] + flat_side_inputs = nest.flatten(side_inputs) + + custom_grad_fn = ( + self._efficient_grad_fn if self._use_efficient_backprop else None) + + @_fn_with_custom_grad(custom_grad_fn) + def _forward_wrap(x1_, x2_, *flat_side_inputs): + f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs) + return _rev_block_forward( + x1_, + x2_, + self.f, + self.g, + num_layers=self.num_layers, + f_side_input=f_side, + g_side_input=g_side, + gate_outputs=self._use_efficient_backprop) + + return _forward_wrap(x1, x2, *flat_side_inputs) + + def _backward(self, y1, y2): + """Run backward through the reversible layers.""" + + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() + + for i in xrange(self.num_layers): + gy1 = g[i](y1, self.g_side_input) if self.g_side_input else g[i](y1) + x2 = y2 - gy1 + fx2 = f[i](x2, self.f_side_input) if self.f_side_input else f[i](x2) + x1 = y1 - fx2 + + y1, y2 = x1, x2 + + return x1, x2 + + +def rev_block(x1, + x2, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + is_training=True): + """A block of reversible residual layers. + + A reversible residual layer is defined as: + + ``` + y1 = x1 + f(x2, f_side_input) + y2 = x2 + g(y1, g_side_input) + ``` + + A reversible residual block, defined here, is a series of reversible residual + layers. + + Limitations: + * f and g must not close over any Tensors; all side inputs to f and g should + be passed in with f_side_input and g_side_input which will be forwarded to + f and g. + * f and g must not change the dimensionality of their inputs in order for the + addition in the equations above to work. + + Args: + x1: a float Tensor. + x2: a float Tensor. + f: a function, (Tensor) -> (Tensor) (or list of such of length num_layers). + Should not change the shape of the Tensor. Can make calls to get_variable. + See f_side_input if there are side inputs. + g: a function, (Tensor) -> (Tensor) (or list of such of length num_layers). + Should not change the shape of the Tensor. Can make calls to get_variable. + See g_side_input if there are side inputs. + num_layers: int, number of reversible residual layers. Each layer will + apply f and g according to the equations above, with new variables in each + layer. + f_side_input: list of Tensors, side input to f. If not None, signature of f + should be (Tensor, list) -> (Tensor). + g_side_input: list of Tensors, side input to g. If not None, signature of g + should be (Tensor, list) -> (Tensor). + is_training: bool, whether to actually use the efficient backprop codepath. + + Returns: + y1, y2: tuple of float Tensors. + """ + block = RevBlock( + f=f, + g=g, + num_layers=num_layers, + f_side_input=f_side_input, + g_side_input=g_side_input, + use_efficient_backprop=is_training, + _reuse=variable_scope.get_variable_scope().reuse) + return block.forward(x1, x2) + + +def recompute_grad(fn): + """Decorator that recomputes the function on the backwards pass. + + Args: + fn: a function that takes Tensors (all as positional arguments) and returns + a tuple of Tensors. + + Returns: + A wrapped fn that is identical to fn when called, but its activations will + be discarded and recomputed on the backwards pass (i.e. on a call to + tf.gradients). + """ + + @functools.wraps(fn) + def wrapped(*args): + return _recompute_grad(fn, args) + + return wrapped + + +def _recompute_grad(fn, args): + """See recompute_grad.""" + + cached_vs = [] + cached_arg_scope = [] + + def grad_fn(inputs, variables, outputs, output_grads): + """Recompute outputs for gradient computation.""" + del outputs + # Recompute outputs + with framework_ops.control_dependencies(output_grads): + with contrib_framework_ops.arg_scope(cached_arg_scope[0]): + with variable_scope.variable_scope(cached_vs[0], reuse=True): + outputs = fn(*inputs) + + if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + outputs = [outputs] + outputs = list(outputs) + grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + @_fn_with_custom_grad(grad_fn) + def fn_with_recompute(*args): + cached_vs.append(variable_scope.get_variable_scope()) + # TODO(rsepassi): Rm conditional in TF 1.4 + if hasattr(contrib_framework_ops, "current_arg_scope"): + cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) + else: + cached_arg_scope.append({}) + return fn(*args) + + return fn_with_recompute(*args) + + +def _underlying_variable_ref(t): + """Find the underlying variable ref. + + Traverses through Identity, ReadVariableOp, and Enter ops. + Stops when op type has Variable or VarHandle in name. + + Args: + t: a Tensor + + Returns: + a Tensor that is a variable ref, or None on error. + """ + while t.op.type in ["Identity", "ReadVariableOp", "Enter"]: + t = t.op.inputs[0] + + op_type = t.op.type + if "Variable" in op_type or "VarHandle" in op_type: + return t + else: + return None + + +def _fn_with_custom_grad(grad_fn, use_global_vars=False): + """Decorator to create a subgraph with a custom gradient function. + + The subgraph created by the decorated function is NOT put in a Defun and so + does not suffer from the limitations of the Defun (all subgraph ops on the + same device, no summaries). + + Args: + grad_fn: function with signature + (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars), + all of which are lists of Tensors. + use_global_vars: if True, variables will be the global variables created. + If False, will be the trainable variables. + + Returns: + Decorator for function such that the gradient is defined by grad_fn. + """ + + def dec(fn): + + @functools.wraps(fn) + def wrapped(*args): + return _fn_with_custom_grad_internal( + fn, args, grad_fn, use_global_vars=use_global_vars) + + return wrapped + + return dec + + +def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): + """Create a subgraph with a custom gradient. + + Args: + fn: function that takes inputs as arguments and produces 1 or more Tensors. + inputs: list, will be passed as fn(*inputs). + grad_fn: function with signature + (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars), + all of which are lists of Tensors. + use_global_vars: if True, variables will be the global variables created. + If False, will be the trainable variables. + + Returns: + fn(*inputs) + """ + vs = variable_scope.get_variable_scope() + get_vars_fn = ( + vs.global_variables if use_global_vars else vs.trainable_variables) + len_before_vars = len(get_vars_fn()) + inputs = list(inputs) + outputs = fn(*inputs) + train_vars = get_vars_fn()[len_before_vars:] + + if grad_fn is None: + return outputs + + if not (isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = [outputs] + outputs = list(outputs) + + defun_inputs = [inputs, train_vars, outputs] + + def custom_grad_fn(op, *dys): + """Custom grad fn applying grad_fn for identity Defun.""" + fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as( + defun_inputs, list(op.inputs)) + dys = list(dys) + assert len(fn_outputs) == len(outputs) + assert len(fn_outputs) == len(dys) + + grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys) + grad_outputs = [None] * len(fn_outputs) + return tuple(grad_inputs + grad_vars + grad_outputs) + + # The Defun takes as input the original inputs, the trainable variables + # created in fn, and the outputs. In the forward it passes through the + # outputs. In the backwards, it produces gradients for the original inputs + # and the trainable variables. + in_types = [t.dtype for t in inputs] + out_types = [t.dtype for t in outputs] + var_types = [t.dtype for t in train_vars] + + # Get a unique name for the Defun + with framework_ops.name_scope("identity_custom_grad") as ns: + defun_name = ns + + @function.Defun( + *(in_types + var_types + out_types), + func_name=defun_name, + python_grad_func=custom_grad_fn, + shape_func=lambda _: [t.get_shape() for t in outputs]) + def identity(*args): + _, _, outs = nest.pack_sequence_as(defun_inputs, args) + return tuple([array_ops.identity(t) for t in outs]) + + flat_inputs = nest.flatten(defun_inputs) + id_out = identity(*flat_inputs) + return id_out diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcbcd75114a522b95631e4e7e95c1641b0a9987 --- /dev/null +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -0,0 +1,364 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for RevBlock.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.layers.python.layers import rev_block_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.layers import convolutional +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RevBlockTest(test.TestCase): + CHANNELS = 8 + NUM_LAYERS = 4 + BATCH_SIZE = 16 + + def testForwardBackward(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + block = rev_block_lib.RevBlock(f, g, num_layers=3) + y1, y2 = block.forward(x1, x2) + x1_inv, x2_inv = block.backward(y1, y2) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv]) + + self.assertAllClose(x1, x1_inv) + self.assertAllClose(x2, x2_inv) + + def testBackwardForward(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + y = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + y1, y2 = array_ops.split(y, 2, axis=-1) + + block = rev_block_lib.RevBlock(f, g, num_layers=3) + x1, x2 = block.backward(y1, y2) + y1_inv, y2_inv = block.forward(x1, x2) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv]) + + self.assertAllClose(y1, y1_inv) + self.assertAllClose(y2, y2_inv) + + def _testRevBlock(self, + x=None, + f=None, + g=None, + f_side_input=None, + g_side_input=None): + random_seed.set_random_seed(1234) + + if f is None: + + def f(x): # pylint: disable=function-redefined + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + if g is None: + + def g(x): # pylint: disable=function-redefined + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + if f_side_input is None: + f_side_input = [] + + if g_side_input is None: + g_side_input = [] + + if x is None: + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + with variable_scope.variable_scope("rev_test") as vs: + y1_rev, y2_rev = rev_block_lib.rev_block( + x1, + x2, + f, + g, + f_side_input=f_side_input, + g_side_input=g_side_input, + num_layers=self.NUM_LAYERS) + y_rev = array_ops.concat([y1_rev, y2_rev], axis=1) + fg_vars = vs.trainable_variables() + + num_vars = len(variables.global_variables()) + with variable_scope.variable_scope(vs, reuse=True): + y1, y2 = rev_block_lib.rev_block( + x1, + x2, + f, + g, + f_side_input=f_side_input, + g_side_input=g_side_input, + num_layers=self.NUM_LAYERS, + is_training=False) + y = array_ops.concat([y1, y2], axis=1) + # Ensure no new vars were created - full reuse + assert len(variables.global_variables()) == num_vars + + loss_rev = math_ops.reduce_mean(y_rev + 10.) + loss = math_ops.reduce_mean(y + 10.) + + wrt = [x] + f_side_input + g_side_input + fg_vars + grads_rev = gradients_impl.gradients(loss_rev, wrt) + grads = gradients_impl.gradients(loss, wrt) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) + self.assertAllClose(y_val, yd_val) + for g1, g2 in zip(gd_val, g_val): + self.assertAllClose(g1, g2) + + def testRevBlock(self): + self._testRevBlock() + + def testSideInput(self): + f_side_input = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS // 2]) + + def f(x, side_input): + return core_layers.dense( + x, self.CHANNELS // 2, use_bias=True) + side_input[0] + + self._testRevBlock(f=f, f_side_input=[f_side_input]) + + def testMultipleFns(self): + + def f1(x): + return core_layers.dense(x, self.CHANNELS // 2) + + def f2(x): + return core_layers.dense(x, self.CHANNELS // 2, activation=nn_ops.relu) + + self._testRevBlock(f=[f1, f2, f1, f2]) + + # TODO(rsepassi): Recent change to conv seems to have broken this test. Find + # out why. + def _testConvAndBatchNorm(self): + + x = random_ops.random_uniform( + [self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32) + + def f(x): + x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") + x = layers.batch_norm(x, is_training=True) + x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") + x = layers.batch_norm(x, is_training=True) + return x + + self._testRevBlock(x=x, f=f) + + def testReuse(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2) + + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + with variable_scope.variable_scope("test"): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_before = len(variables.global_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + + loss = math_ops.reduce_mean(y1 + y2) + _ = gradients_impl.gradients(loss, + [x] + variables.trainable_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + + +class RecomputeTest(test.TestCase): + + def testRecompute(self): + + def layer(x, name=None): + with variable_scope.variable_scope(name, default_name="layer"): + x = layers.layer_norm(x) + x = convolutional.conv1d( + x, + 10, + 1, + use_bias=False, + kernel_initializer=init_ops.constant_initializer(42.42)) + x = nn_ops.relu(x) + return x + + def fn(x): + out = x + for _ in range(3): + out = layer(out) + return out + + @rev_block_lib.recompute_grad + def fn_recompute(x): + return fn(x) + + x = random_ops.random_uniform((3, 1, 3)) + recompute_vars = None + with variable_scope.variable_scope("recompute") as vs: + out1 = math_ops.reduce_sum(fn_recompute(x)) + recompute_vars = vs.trainable_variables() + reg_vars = None + with variable_scope.variable_scope("regular") as vs: + out2 = math_ops.reduce_sum(fn(x)) + reg_vars = vs.trainable_variables() + + grad1 = gradients_impl.gradients(out1, recompute_vars) + grad2 = gradients_impl.gradients(out2, reg_vars) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + outs = sess.run([out1, out2, grad1, grad2]) + self.assertAllClose(outs[0], outs[1]) + for g1, g2 in zip(outs[2], outs[3]): + self.assertAllClose(g1, g2) + + +class FnWithCustomGradTest(test.TestCase): + + def testCorrectness(self): + + w = random_ops.random_uniform([6, 10]) + + def fn(a, b, c): + return core_layers.dense( + a, + 10, + use_bias=False, + kernel_initializer=lambda shape, dtype, partition_info: w + ) + math_ops.matmul(b, c) + + def grad_fn(inputs, trainable_variables, outputs, grad_outputs): + outputs = outputs[0] + grad_outputs = grad_outputs[0] + grad_inputs = gradients_impl.gradients( + outputs, inputs, grad_ys=grad_outputs) + grad_vars = gradients_impl.gradients( + outputs, trainable_variables, grad_ys=grad_outputs) + return grad_inputs, grad_vars + + custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn) + + a = random_ops.random_uniform([11, 6]) + b = random_ops.random_uniform([11, 7]) + c = random_ops.random_uniform([7, 10]) + + out = fn(a, b, c) + custom_out = custom_fn(a, b, c) + self.assertEqual(out.get_shape().as_list(), + custom_out.get_shape().as_list()) + + loss = math_ops.reduce_mean(out) + custom_loss = math_ops.reduce_mean(custom_out) + + grads = gradients_impl.gradients( + loss, [a, b, c] + [variables.trainable_variables()[0]]) + custom_grads = gradients_impl.gradients( + custom_loss, [a, b, c] + [variables.trainable_variables()[1]]) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + out_val, custom_out_val, grads_val, custom_grads_val = sess.run( + [out, custom_out, grads, custom_grads]) + self.assertAllClose(out_val, custom_out_val) + for g1, g2 in zip(grads_val, custom_grads_val): + self.assertAllClose(g1, g2) + + def testCustomGrad(self): + + def fn(a, b, c): + return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c) + + def grad_fn(inputs, trainable_variables, unused_outputs, + unused_grad_outputs): + grad_inputs = [ + array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs) + ] + grad_vars = [ + array_ops.ones_like(t) * (i + len(inputs) + 1.) + for i, t in enumerate(trainable_variables) + ] + return grad_inputs, grad_vars + + a = random_ops.random_uniform([11, 6]) + b = random_ops.random_uniform([11, 7]) + c = random_ops.random_uniform([7, 10]) + w = random_ops.random_uniform([6, 10]) + out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c) + loss = math_ops.reduce_mean(out) + grads = gradients_impl.gradients( + loss, [a, b, c, variables.trainable_variables()[0]]) + expected_grads = [ + array_ops.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w]) + ] + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + g_val, eg_val = sess.run([grads, expected_grads]) + for g1, g2 in zip(g_val, eg_val): + self.assertAllClose(g1, g2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 94920db574e07529c28313a78e0128676fcc7970..5df2c77249b81434125d838f896f0ace2a5ee130 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -10,7 +10,7 @@ package(default_visibility = [ "//tensorflow:internal", ]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_library( name = "learn", @@ -154,12 +154,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "experiment_test", size = "medium", srcs = ["python/learn/experiment_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", "//tensorflow/core:protos_all_py", @@ -346,6 +345,7 @@ py_test( srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], shard_count = 4, srcs_version = "PY2AND3", + tags = ["no_oss"], # flaky b/70524820 deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -461,6 +461,7 @@ py_test( size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], srcs_version = "PY2AND3", + tags = ["noasan"], deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -715,12 +716,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "graph_io_test", size = "small", srcs = ["python/learn/learn_io/graph_io_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -736,6 +736,7 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], + grpc_enabled = True, ) py_test( diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py index 14750961efa30128708430fac038498de0a42118..ef5e620e8f08cffa7c2b945089aa5d150baefefc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import composable_model @@ -55,7 +55,7 @@ def _base_model_fn(features, labels, mode, params): raise NotImplementedError def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() assert global_step train_step = model.get_train_step(loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index cb15ef23e95d27c737d8ae08065b804bafd39a07..c17b41c0f767e19d9c3635a8f60347a49b297cfb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -23,7 +23,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -189,7 +189,7 @@ def _dnn_model_fn(features, labels, mode, params, config=None): """Returns the op to optimize the loss.""" return optimizers.optimize_loss( loss=loss, - global_step=contrib_variables.get_global_step(), + global_step=training_util.get_global_step(), learning_rate=_LEARNING_RATE, optimizer=_get_optimizer(optimizer), gradient_multipliers=( diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 788d2d0b1a58fad16712c968593b40de0d3979f0..05ed8b3409e68ae54e5ef89b3a1592a6f285565b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -30,7 +30,6 @@ import six from google.protobuf import message from tensorflow.contrib import layers -from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables @@ -60,6 +59,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -1230,7 +1230,7 @@ class Estimator(BaseEstimator): if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( - metrics_lib.streaming_mean(model_fn_ops.loss)) + metrics_lib.mean(model_fn_ops.loss)) return model_fn_ops def _get_predict_ops(self, features): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py index 248c6c733ffca351c848ba07110ba89928634a23..9d7c1a099aa4be64ca0296fa5b870597dabec7b4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py @@ -23,7 +23,7 @@ import tempfile import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn import models @@ -114,7 +114,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -129,7 +129,7 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return prediction, loss, train_op @@ -139,7 +139,7 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -150,7 +150,7 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index be2b0cb3ca959323b4de095ca072278f028be301..2a13a84627df35a68a4f04b25ab26ceecad0db0d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -32,7 +32,7 @@ from google.protobuf import text_format from tensorflow.contrib import learn from tensorflow.contrib import lookup -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import experiment @@ -132,7 +132,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -147,7 +147,7 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return prediction, loss, train_op @@ -157,7 +157,7 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -168,7 +168,7 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction @@ -241,7 +241,7 @@ def _build_estimator_for_resource_export_test(): const = constant_op.constant(-1, dtype=dtypes.int64) table = lookup.MutableHashTable( dtypes.string, dtypes.int64, const, name='LookupTableModel') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL): key = constant_op.constant(['key']) value = constant_op.constant([42], dtype=dtypes.int64) @@ -306,7 +306,7 @@ def _model_fn_ops( mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1)) + train_op=training_util.get_global_step().assign_add(1)) def _make_input_fn(features, labels): @@ -389,7 +389,7 @@ class EstimatorModelFnTest(test.TestCase): self.assertEqual(expected_param, params) self.assertEqual(model_dir, expected_model_dir) return (constant_op.constant(0.), constant_op.constant(0.), - variables.get_global_step().assign_add(1)) + training_util.get_global_step().assign_add(1)) est = estimator.Estimator(model_fn=_argument_checker, params=expected_param, model_dir=expected_model_dir) @@ -400,7 +400,7 @@ class EstimatorModelFnTest(test.TestCase): def _invalid_model_fn(features, labels): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): loss = 100.0 - w return None, loss, None @@ -415,7 +415,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) predictions = loss @@ -434,7 +434,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) return None, loss, train_op @@ -464,7 +464,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(init_fn=_init_fn)) est = estimator.Estimator(model_fn=_model_fn_scaffold) @@ -483,7 +483,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant([[1.]]), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(saver=self.mock_saver)) def input_fn(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py index 1d89dfb55b10b032cab7dcf434d396404d4eb83b..8131e0fde6fea5501cacc4714f53ed8d867ca70f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py @@ -22,7 +22,7 @@ import random import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python.learn import datasets from tensorflow.contrib.learn.python.learn import metric_spec @@ -62,7 +62,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["transformed_x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -100,7 +100,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -139,7 +139,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator_with_fe_fn = estimator_lib.Estimator( diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 992b804f59ecd88fedc2fba10d3079f93c4fe83d..8f9d6fc318a357853bdb8e3264f6691b410006b1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -28,7 +28,7 @@ import time import numpy as np from tensorflow.contrib.factorization.python.ops import clustering_ops -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps from tensorflow.python.framework import ops @@ -128,7 +128,7 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): random_seed=params.get('random_seed'), kmeans_plus_plus_num_retries=params.get( 'kmeans_plus_plus_num_retries')).training_graph() - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses, name=KMeansClustering.LOSS_OP_NAME) summary.scalar('loss/raw', loss) training_op = with_dependencies([training_op, incr_step], loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index f5445ad4e728dbd3904279573771de9454b5d17c..37aa8b339622415d082933cdf66d2472a4119b48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -26,7 +26,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -170,7 +170,7 @@ def _linear_model_fn(features, labels, mode, params, config=None): weight_collections=[parent_scope]) def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() my_vars = ops.get_collection(parent_scope) grads = gradients.gradients(loss, my_vars) if gradient_clip_norm: @@ -252,7 +252,7 @@ def sdca_model_fn(features, labels, mode, params): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step(columns_to_variables, weight_column_name, loss_type, features, diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py index 93c62f87e8495f299a8c456574c7b40534186304..656d68b76888d9319c0b9be481f9b0478ac4314c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor @@ -57,7 +57,7 @@ def _logistic_regression_model_fn(features, labels, mode): predictions = math_ops.sigmoid(logits) loss = losses.sigmoid_cross_entropy(labels, logits) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return predictions, loss, train_op diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 307db76afe20a7743df16d169270a6f319497eb6..9576ff21c243022276bb0641882dfaf0decf05c0 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks @@ -46,6 +47,18 @@ from tensorflow.python.util import compat __all__ = ["Experiment"] +def _get_standardized_predicate_fn(predicate_fn): + pred_fn_args = estimator_util.fn_args(predicate_fn) + if "checkpoint_path" not in pred_fn_args: + # pylint: disable=unused-argument + def _pred_fn_wrapper(eval_results, checkpoint_path): + return predicate_fn(eval_results) + + return _pred_fn_wrapper + else: + return predicate_fn + + class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): """Listener that evaluates and exports a model after creating a checkpoint. @@ -140,7 +153,8 @@ class Experiment(object): delay_workers_by_global_step=False, export_strategies=None, train_steps_per_iteration=None, - checkpoint_and_export=False): + checkpoint_and_export=False, + saving_listeners=None): """Constructor for `Experiment`. Creates an Experiment instance. None of the functions passed to this @@ -200,6 +214,9 @@ class Experiment(object): `save_checkpoints_steps`. Also, this parameter leads to the creation of a default `CheckpointSaverHook` instead of a `ValidationMonitor`, so the provided `train_monitors` will need to be adjusted accordingly. + saving_listeners: list of `CheckpointSaverListener` objects. Used by + tf.estimator.Estimator for callbacks that run immediately before or + after checkpoint savings. Raises: ValueError: if `estimator` does not implement Estimator interface, @@ -221,6 +238,9 @@ class Experiment(object): raise ValueError( "`estimator` must implement `tf.contrib.learn.Trainable`" "or `tf.estimator.`Estimator`.") + if saving_listeners is not None: + raise ValueError("`saving_listeners` must be `None` with " + "`tf.contrib.learn.Estimator`.") if isinstance(estimator, tpu_estimator.TPUEstimator): logging.warn( @@ -242,6 +262,7 @@ class Experiment(object): self._eval_delay_secs = eval_delay_secs self._continuous_eval_throttle_secs = continuous_eval_throttle_secs self._checkpoint_and_export = checkpoint_and_export + self._saving_listeners = saving_listeners # Using 1 on a non-cached file system requires a lot of overhead to # read the checkpoint state file. This is particular bad on GCS, so # we use a different default. This is a temporary band-aid, to be @@ -362,9 +383,11 @@ class Experiment(object): logging.info("Waiting %d secs before starting training.", remaining) time.sleep(delay_secs) - return self._call_train(input_fn=self._train_input_fn, - max_steps=self._train_steps, - hooks=self._train_monitors + extra_hooks) + return self._call_train( + input_fn=self._train_input_fn, + max_steps=self._train_steps, + hooks=self._train_monitors + extra_hooks, + saving_listeners=self._saving_listeners) def evaluate(self, delay_secs=None, name=None): """Evaluate on the evaluation data. @@ -436,22 +459,33 @@ class Experiment(object): evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints that have already been evaluated. Default is `True`. continuous_eval_predicate_fn: A predicate function determining whether to - continue eval after each iteration. `predicate_fn` takes the evaluation - results as arguments. At the beginning of evaluation, the passed eval - results will be None so it's expected that the predicate function - handles that gracefully. When `predicate_fn` is not specified, - continuous eval will run in an infinite loop (if `train_steps` is None) - or exit once global step reaches `train_steps`. + continue eval after each iteration. A `predicate_fn` has one of the + following signatures: + * (eval_results) -> boolean + * (eval_results, checkpoint_path) -> boolean + Where `eval_results` is the dictionary of metric evaluations and + checkpoint_path is the path to the checkpoint containing the parameters + on which that evaluation was based. + At the beginning of evaluation, the passed `eval_results` will be None + so it's expected that the predicate function handles that gracefully. + When `predicate_fn` is not specified, continuous eval will run in an + infinite loop (if `train_steps` is None). or exit once global step + reaches `train_steps`. + export: Whether to export from this step. Default is 'True'. Raises: ValueError: if `continuous_eval_predicate_fn` is neither None nor callable. """ - if (continuous_eval_predicate_fn is not None and - not callable(continuous_eval_predicate_fn)): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") + if continuous_eval_predicate_fn is not None: + if not callable(continuous_eval_predicate_fn): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + predicate_fn = _get_standardized_predicate_fn( + continuous_eval_predicate_fn) + else: + predicate_fn = None if delay_secs is None: delay_secs = self._eval_delay_secs @@ -465,8 +499,10 @@ class Experiment(object): previous_path = None eval_result = None last_warning_time = 0 - while (not continuous_eval_predicate_fn or - continuous_eval_predicate_fn(eval_result)): + while (not predicate_fn or + predicate_fn( + eval_result, + checkpoint_path=previous_path if eval_result else None)): # Exit if we have already reached number of steps to train. if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " @@ -672,11 +708,19 @@ class Experiment(object): Args: continuous_eval_predicate_fn: A predicate function determining whether to - continue after each iteration. `predicate_fn` takes the evaluation - results as its arguments. At the beginning of evaluation, the passed - eval results will be None so it's expected that the predicate function - handles that gracefully. When `predicate_fn` is not specified, this will - run in an infinite loop or exit when global_step reaches `train_steps`. + continue eval after each iteration. A `predicate_fn` has one of the + following signatures: + * (eval_results) -> boolean + * (eval_results, checkpoint_path) -> boolean + Where `eval_results` is the dictionary of metric evaluations and + checkpoint_path is the path to the checkpoint containing the parameters + on which that evaluation was based. + At the beginning of evaluation, the passed `eval_results` and + `checkpoint_path` will be None so it's expected that the predicate + function handles that gracefully. + When `predicate_fn` is not specified, continuous eval will run in an + infinite loop (if `train_steps` is None). or exit once global step + reaches `train_steps`. Returns: A tuple of the result of the `evaluate` call to the `Estimator` and the @@ -687,13 +731,18 @@ class Experiment(object): callable. """ - if (continuous_eval_predicate_fn is not None and - not callable(continuous_eval_predicate_fn)): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") + if continuous_eval_predicate_fn is not None: + if not callable(continuous_eval_predicate_fn): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + predicate_fn = _get_standardized_predicate_fn( + continuous_eval_predicate_fn) + else: + predicate_fn = None - eval_result = None export_results = None + latest_checkpoint = None + eval_result = None # Set the default value for train_steps_per_iteration, which will be # overridden by other settings. @@ -703,8 +752,10 @@ class Experiment(object): elif self._train_steps is not None: train_steps_per_iteration = int(self._train_steps / 10) - while (not continuous_eval_predicate_fn or - continuous_eval_predicate_fn(eval_result)): + while (not predicate_fn or + predicate_fn( + eval_result, + checkpoint_path=latest_checkpoint if eval_result else None)): if self._has_training_stopped(eval_result): # Exits once max steps of training is satisfied. @@ -712,16 +763,21 @@ class Experiment(object): break logging.info("Training model for %s steps", train_steps_per_iteration) - self._call_train(input_fn=self._train_input_fn, - steps=train_steps_per_iteration, - hooks=self._train_monitors) + self._call_train( + input_fn=self._train_input_fn, + steps=train_steps_per_iteration, + hooks=self._train_monitors, + saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - eval_result = self._call_evaluate(input_fn=self._eval_input_fn, - steps=self._eval_steps, - metrics=self._eval_metrics, - name="one_pass", - hooks=self._eval_hooks) + latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + eval_result = self._call_evaluate( + input_fn=self._eval_input_fn, + steps=self._eval_steps, + metrics=self._eval_metrics, + name="one_pass", + checkpoint_path=latest_checkpoint, + hooks=self._eval_hooks) export_results = self._maybe_export(eval_result) return eval_result, export_results @@ -762,9 +818,11 @@ class Experiment(object): Returns: The result of the `evaluate` call to the `Estimator`. """ - self._call_train(input_fn=self._train_input_fn, - steps=1, - hooks=self._train_monitors) + self._call_train( + input_fn=self._train_input_fn, + steps=1, + hooks=self._train_monitors, + saving_listeners=self._saving_listeners) eval_result = self._call_evaluate(input_fn=self._eval_input_fn, steps=1, @@ -792,7 +850,8 @@ class Experiment(object): return server def _call_train(self, _sentinel=None, # pylint: disable=invalid-name, - input_fn=None, steps=None, hooks=None, max_steps=None): + input_fn=None, steps=None, hooks=None, max_steps=None, + saving_listeners=None): if _sentinel is not None: raise ValueError("_call_train should be called with keyword args only") @@ -801,10 +860,12 @@ class Experiment(object): # safe to convert for both cases. hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator) if self._core_estimator_used: - return self._estimator.train(input_fn=input_fn, - steps=steps, - max_steps=max_steps, - hooks=hooks) + return self._estimator.train( + input_fn=input_fn, + steps=steps, + max_steps=max_steps, + hooks=hooks, + saving_listeners=saving_listeners) else: return self._estimator.fit(input_fn=input_fn, steps=steps, diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index fe40d27c445d4f560c96fc9b50ceb0daed30ee93..545d7d8924c0c10544e6113e2968b7ae3d2090fc 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -232,14 +232,19 @@ class ExperimentTest(test.TestCase): def test_train(self): for est in self._estimators_for_tests(): - eval_metrics = 'eval_metrics' if not isinstance( - est, core_estimator.Estimator) else None + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None ex = experiment.Experiment( est, train_input_fn='train_input', train_steps='train_steps', eval_input_fn='eval_input', - eval_metrics=eval_metrics) + eval_metrics=eval_metrics, + saving_listeners=saving_listeners) fit_args = ex.train(delay_secs=0) self.assertEqual(1, est.fit_count) self.assertIn(('max_steps', 'train_steps'), fit_args) @@ -487,6 +492,33 @@ class ExperimentTest(test.TestCase): self.assertEqual(3, est.eval_count) self.assertEqual([noop_hook], est.eval_hooks) + def test_continuous_eval_predicate_fn_with_checkpoint(self): + for est in self._estimators_for_tests(): + eval_metrics = 'eval_metrics' if not isinstance( + est, core_estimator.Estimator) else None + est.fake_checkpoint() + noop_hook = _NoopHook() + + def _predicate_fn(eval_result, checkpoint_path): + self.assertEqual(not eval_result, + checkpoint_path is None) + return est.eval_count < 3 # pylint: disable=cell-var-from-loop + + ex = experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input', + eval_metrics=eval_metrics, + eval_hooks=[noop_hook], + eval_delay_secs=0, + continuous_eval_throttle_secs=0) + ex.continuous_eval( + evaluate_checkpoint_only_once=False, + continuous_eval_predicate_fn=_predicate_fn) + self.assertEqual(0, est.fit_count) + self.assertEqual(3, est.eval_count) + self.assertEqual([noop_hook], est.eval_hooks) + def test_run_local(self): for est in self._estimators_for_tests(): eval_metrics = 'eval_metrics' if not isinstance( @@ -675,8 +707,12 @@ class ExperimentTest(test.TestCase): def test_continuous_train_and_eval(self): for est in self._estimators_for_tests(eval_dict={'global_step': 100}): - eval_metrics = 'eval_metrics' if not isinstance( - est, core_estimator.Estimator) else None + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None noop_hook = _NoopHook() export_strategy = saved_model_export_utils.make_export_strategy( est, @@ -690,7 +726,8 @@ class ExperimentTest(test.TestCase): eval_hooks=[noop_hook], train_steps=100, eval_steps=100, - export_strategies=export_strategy) + export_strategies=export_strategy, + saving_listeners=saving_listeners) ex.continuous_train_and_eval() self.assertEqual(1, est.fit_count) self.assertEqual(1, est.eval_count) @@ -742,9 +779,10 @@ class ExperimentTest(test.TestCase): ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn) mock_estimator.train.assert_called_once_with( input_fn='train_input', - steps=int(total_steps/10), + steps=int(total_steps / 10), max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_steps_per_iteration_from_user(self): mock_estimator = test.mock.Mock(core_estimator.Estimator) @@ -768,7 +806,8 @@ class ExperimentTest(test.TestCase): input_fn='train_input', steps=1234, max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_default_steps_per_iteration(self): mock_estimator = test.mock.Mock(core_estimator.Estimator) @@ -791,7 +830,8 @@ class ExperimentTest(test.TestCase): input_fn='train_input', steps=1000, max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_invalid_predicate_fn(self): for est in self._estimators_for_tests(): @@ -857,11 +897,19 @@ class ExperimentTest(test.TestCase): est, None if isinstance(est, core_estimator.Estimator) else 'export_input', exports_to_keep=None) + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None ex = experiment.Experiment( est, train_input_fn='train_input', eval_input_fn='eval_input', - export_strategies=(exp_strategy,)) + export_strategies=(exp_strategy,), + eval_metrics=eval_metrics, + saving_listeners=saving_listeners) ex.test() self.assertEqual(1, est.fit_count) self.assertEqual(1, est.eval_count) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index db18ebf05d5fb98e28e767be7bcccdf992a56fd8..f36a778b529a83f158241ddb060959c4b33e2e95 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -28,7 +28,6 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -369,10 +368,11 @@ class DataFeeder(object): if x_is_dict: num_samples = list(self._x.values())[0].shape[0] elif tensor_util.is_tensor(self._x): - num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int + num_samples = self._x.shape[ + 0].value # shape will be a Dimension, extract an int else: num_samples = self._x.shape[0] - + if self._shuffle: self.indices = self.random_state.permutation(num_samples) else: @@ -857,8 +857,8 @@ class DaskDataFeeder(object): """Returns a function, that will sample data and provide it to placeholders. Args: - input_placeholder: tf.Placeholder for input features mini batch. - output_placeholder: tf.Placeholder for output labels. + input_placeholder: tf.placeholder for input features mini batch. + output_placeholder: tf.placeholder for output labels. Returns: A function that when called samples a random subset of batch size diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 4b34fc62849766370979bb2002d42ee03ea7161a..3a46c239688017f9204d2c6182a6f81cd325a417 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import io_ops @@ -280,14 +281,33 @@ def _get_file_names(file_pattern, randomize_input): def _get_examples(file_name_queue, reader, num_threads, read_batch_size, filter_fn, parse_fn): + """Get example filenames matching. + + Args: + file_name_queue: A queue implementation that dequeues elements in + first-in first-out order. + reader: A function or class that returns an object with + `read` method, (filename tensor) -> (example tensor). + num_threads: The number of threads enqueuing examples. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. + filter_fn: Filtering function, takes both keys as well as an `Example` + Tensors and returns a boolean mask of the same shape as the input Tensors + to be applied for filtering. If `None`, no filtering is done. + parse_fn: Parsing function, takes `Example` Tensor returns parsed + representation. If `None`, no parsing is done. + + Returns: + List of example file names matching `file_name_queue`. + """ with ops.name_scope('read'): example_list = [] for _ in range(num_threads): - if read_batch_size > 1: - keys, examples_proto = reader().read_up_to(file_name_queue, - read_batch_size) - else: - keys, examples_proto = reader().read(file_name_queue) + keys, examples_proto = utils.smart_cond( + read_batch_size > 1, + lambda: reader().read_up_to(file_name_queue, read_batch_size), + lambda: reader().read(file_name_queue)) + if filter_fn: mask = filter_fn(keys, examples_proto) keys = array_ops.boolean_mask(keys, mask) @@ -379,14 +399,15 @@ def _read_keyed_batch_examples_helper(file_pattern, capacity=1, dtypes=[dtypes.string], shapes=[[]]) enqueue_op = file_name_queue.enqueue( input_pipeline_ops.seek_next( - file_names, shuffle=randomize_input, num_epochs=num_epochs, + file_names, + shuffle=randomize_input, + num_epochs=num_epochs, seed=seed)) queue_runner.add_queue_runner( queue_runner.QueueRunner(file_name_queue, [enqueue_op])) else: file_name_queue = input_ops.string_input_producer( - constant_op.constant( - file_names, name='input'), + constant_op.constant(file_names, name='input'), shuffle=randomize_input, num_epochs=num_epochs, name=file_name_queue_scope, @@ -496,7 +517,8 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: - if read_batch_size is None: read_batch_size = batch_size + if read_batch_size is None: + read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 6f0fd9a2976d37d1c701a96f50c2b987562cb191..e11e8b698adc113486bbb45572c8129e964cc931 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -204,8 +204,7 @@ class GraphIOTest(test.TestCase): shape = (0,) features = { "feature": - parsing_ops.FixedLenFeature( - shape=shape, dtype=dtypes_lib.float32) + parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32) } with ops.Graph().as_default() as g, self.test_session(graph=g) as sess: @@ -255,8 +254,8 @@ class GraphIOTest(test.TestCase): self.assertAllEqual((None,), inputs.get_shape().as_list()) self.assertEqual("%s:1" % name, inputs.name) file_name_queue_name = "%s/file_name_queue" % name - file_name_queue_limit_name = ("%s/limit_epochs/epochs" % - file_name_queue_name) + file_name_queue_limit_name = ( + "%s/limit_epochs/epochs" % file_name_queue_name) file_names_name = "%s/input" % file_name_queue_name example_queue_name = "%s/random_shuffle_queue" % name op_nodes = test_util.assert_ops_in_graph({ @@ -354,8 +353,8 @@ class GraphIOTest(test.TestCase): json_lines = [ "".join([ '{"features": { "feature": { "sequence": {', - '"bytes_list": { "value": ["', base64.b64encode(l).decode("ascii"), - '"]}}}}}\n' + '"bytes_list": { "value": ["', + base64.b64encode(l).decode("ascii"), '"]}}}}}\n' ]) for l in lines ] return self._create_temp_file("".join(json_lines)) @@ -823,6 +822,31 @@ class GraphIOTest(test.TestCase): coord.request_stop() coord.join(threads) + def test_read_keyed_batch_features_shared_queue(self): + batch_size = 17 + shape = (0,) + fixed_feature = parsing_ops.FixedLenFeature( + shape=shape, dtype=dtypes_lib.float32) + feature = {"feature": fixed_feature} + reader = io_ops.TFRecordReader + + _, queued_feature = graph_io.read_keyed_batch_features_shared_queue( + _VALID_FILE_PATTERN, batch_size, feature, reader) + + with ops.Graph().as_default() as g, self.test_session(graph=g) as session: + features_result = graph_io.read_batch_features( + _VALID_FILE_PATTERN, batch_size, feature, reader) + session.run(variables.local_variables_initializer()) + + self.assertAllEqual( + queued_feature.get("feature").get_shape().as_list(), + features_result.get("feature").get_shape().as_list()) + + def test_get_file_names_errors(self): + # Raise bad file_pattern. + with self.assertRaises(ValueError): + graph_io._get_file_names([], True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index ed6683abedbb8ae76ba364405158eb52cbb6d762..6440bc204b8e339ff51311dcc87b36f556b94092 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -42,10 +42,8 @@ def _args(fn): """ if hasattr(fn, 'func') and hasattr(fn, 'keywords'): # Handle functools.partial and similar objects. - return tuple([ - arg for arg in tf_inspect.getargspec(fn.func).args - if arg not in set(fn.keywords.keys()) - ]) + return tuple( + [arg for arg in _args(fn.func) if arg not in set(fn.keywords.keys())]) # Handle function. return tuple(tf_inspect.getargspec(fn).args) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 6af2287761299f6725f9547917101c18b0cc0164..cb34cb1d26b6812c7f3f39e9f965615de5a8ef07 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -78,7 +78,7 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, default_graph_signature=default_graph_signature, named_graph_signatures=named_graph_signatures, assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) - return export.export(export_dir, contrib_variables.get_global_step(), + return export.export(export_dir, training_util.get_global_step(), session, exports_to_keep=exports_to_keep) @@ -295,7 +295,7 @@ def _export_estimator(estimator, checkpoint_path = (checkpoint_path or tf_saver.latest_checkpoint(estimator._model_dir)) with ops.Graph().as_default() as g: - contrib_variables.create_global_step(g) + training_util.create_global_step(g) if use_deprecated_input_fn: examples = array_ops.placeholder(dtype=dtypes.string, diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 6ffd2a133995a6ff8b35540221fb5676bf5de19f..4b404a8e20e33a17a0d5f857e4220f90c7bc799f 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -33,7 +33,6 @@ from __future__ import division from __future__ import print_function import os -import tempfile import time from tensorflow.contrib.layers.python.layers import feature_column @@ -682,22 +681,36 @@ def extend_export_strategy(base_export_strategy, ValueError: If `estimator` is a ${tf.estimator.Estimator} instance and `default_output_alternative_key` was specified or if post_export_fn does not return a valid directory. + RuntimeError: If unable to create temporary or final export directory. """ - tmp_base_export_dir = tempfile.mkdtemp() + tmp_base_export_folder = 'temp-base-export-' + str(int(time.time())) + tmp_base_export_dir = os.path.join(export_dir_base, tmp_base_export_folder) + if gfile.Exists(tmp_base_export_dir): + raise RuntimeError('Failed to obtain base export directory') + gfile.MakeDirs(tmp_base_export_dir) tmp_base_export = base_export_strategy.export( estimator, tmp_base_export_dir, checkpoint_path) - tmp_post_export_dir = tempfile.mkdtemp() + + tmp_post_export_folder = 'temp-post-export-' + str(int(time.time())) + tmp_post_export_dir = os.path.join(export_dir_base, tmp_post_export_folder) + if gfile.Exists(tmp_post_export_dir): + raise RuntimeError('Failed to obtain temp export directory') + + gfile.MakeDirs(tmp_post_export_dir) tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir) if not tmp_post_export.startswith(tmp_post_export_dir): raise ValueError('post_export_fn must return a sub-directory of {}' .format(tmp_post_export_dir)) - export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) - - gfile.Rename( - os.path.join(tmp_post_export_dir, export_relpath), - os.path.join(export_dir_base, export_relpath)) - return os.path.join(export_dir_base, export_relpath) + post_export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) + post_export = os.path.join(export_dir_base, post_export_relpath) + if gfile.Exists(post_export): + raise RuntimeError('Failed to obtain final export directory') + gfile.Rename(tmp_post_export, post_export) + + gfile.DeleteRecursively(tmp_base_export_dir) + gfile.DeleteRecursively(tmp_post_export_dir) + return post_export name = post_export_name if post_export_name else base_export_strategy.name return export_strategy.ExportStrategy(name, export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index ec3a88003f01b3b62591c13472029601b11ba491..628eb254c3b1129648c453dc47f0c0919891de6f 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -766,10 +766,11 @@ class SavedModelExportUtilsTest(test.TestCase): test_estimator = TestEstimator() tmpdir = tempfile.mkdtemp() - final_path = final_export_strategy.export(test_estimator, tmpdir, - os.path.join( - tmpdir, "checkpoint")) - self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + export_model_dir = os.path.join(tmpdir, "model") + checkpoint_path = os.path.join(tmpdir, "checkpoint") + final_path = final_export_strategy.export(test_estimator, export_model_dir, + checkpoint_path) + self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path) def test_extend_export_strategy_same_name(self): @@ -795,10 +796,11 @@ class SavedModelExportUtilsTest(test.TestCase): test_estimator = TestEstimator() tmpdir = tempfile.mkdtemp() - final_path = final_export_strategy.export(test_estimator, tmpdir, - os.path.join( - tmpdir, "checkpoint")) - self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + export_model_dir = os.path.join(tmpdir, "model") + checkpoint_path = os.path.join(tmpdir, "checkpoint") + final_path = final_export_strategy.export(test_estimator, export_model_dir, + checkpoint_path) + self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path) def test_extend_export_strategy_raises_error(self): diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..df96402a4ffd51840f77d58d8066487030362340 --- /dev/null +++ b/tensorflow/contrib/libsvm/BUILD @@ -0,0 +1,102 @@ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +tf_custom_op_library( + name = "python/ops/_libsvm_ops.so", + srcs = [ + "kernels/decode_libsvm_op.cc", + "ops/libsvm_ops.cc", + ], + deps = [ + "//tensorflow/core/kernels:bounds_check_lib", + ], +) + +tf_kernel_library( + name = "libsvm_kernels", + srcs = ["kernels/decode_libsvm_op.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:bounds_check_lib", + ], +) + +tf_gen_op_libs( + op_lib_names = ["libsvm_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "libsvm_ops", + deps = [":libsvm_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "libsvm", + srcs = [ + "__init__.py", + "python/ops/libsvm_ops.py", + ], + dso = [ + ":python/ops/_libsvm_ops.so", + ], + kernels = [ + ":libsvm_kernels", + ":libsvm_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":libsvm_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + +tf_py_test( + name = "decode_libsvm_op_test", + srcs = ["python/kernel_tests/decode_libsvm_op_test.py"], + additional_deps = [ + ":libsvm", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered_impl.py b/tensorflow/contrib/libsvm/__init__.py similarity index 58% rename from tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered_impl.py rename to tensorflow/contrib/libsvm/__init__.py index 223bc9d042c69be05b0e578835a31ed6e83c0c97..a875863caab29eb59a1834ca9184a5e272cb6656 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered_impl.py +++ b/tensorflow/contrib/libsvm/__init__.py @@ -12,28 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""SigmoidCentered bijector.""" +"""Libsvm decoder. + +@@decode_libsvm +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered +from tensorflow.contrib.libsvm.python.ops.libsvm_ops import decode_libsvm +from tensorflow.python.util.all_util import remove_undocumented -__all__ = [ - "SigmoidCentered", +_allowed_symbols = [ + "decode_libsvm", ] - -class SigmoidCentered(softmax_centered.SoftmaxCentered): - """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). - - Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. - - See `bijector.SoftmaxCentered` for more details. - """ - - def __init__(self, validate_args=False, name="sigmoid_centered"): - super(SigmoidCentered, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) +remove_undocumented(__name__) diff --git a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc7889b27cd9ec50d8d2f7d34975ec8cd16c258f --- /dev/null +++ b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc @@ -0,0 +1,178 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { +namespace { +template +bool ConvertHelper(const string& s, T* value); +} + +template +class DecodeLibsvmOp : public OpKernel { + public: + explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_)); + OP_REQUIRES(ctx, (num_features_ >= 1), + errors::InvalidArgument("Invalid number of features \"", + num_features_, "\"")); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat(); + + Tensor* label_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, input_tensor->shape(), &label_tensor)); + auto label = label_tensor->flat(); + + std::vector out_values; + std::vector> out_indices; + for (int i = 0; i < input_flat.size(); ++i) { + std::vector entries = + str_util::Split(input_flat(i), " ", str_util::SkipEmpty()); + OP_REQUIRES(ctx, (entries.size() > 0), + errors::InvalidArgument("No entries found for input[", i, + "]: \"", input_flat(i), "\"")); + Tlabel label_value; + OP_REQUIRES( + ctx, ConvertHelper(entries[0].c_str(), &label_value), + errors::InvalidArgument("Label format incorrect: ", entries[0])); + label(i) = label_value; + for (int j = 1; j < entries.size(); j++) { + std::vector pair = str_util::Split(entries[j], ":"); + OP_REQUIRES( + ctx, (pair.size() == 2), + errors::InvalidArgument("Invalid feature \"", entries[j], "\"")); + int64 feature_index; + OP_REQUIRES( + ctx, strings::safe_strto64(pair[0].c_str(), &feature_index), + errors::InvalidArgument("Feature format incorrect: ", entries[j])); + OP_REQUIRES(ctx, (feature_index >= 0), + errors::InvalidArgument( + "Feature index should be >= 0, got ", feature_index)); + T feature_value; + OP_REQUIRES( + ctx, ConvertHelper(pair[1], &feature_value), + errors::InvalidArgument("Feature format incorrect: ", entries[j])); + out_values.emplace_back(feature_value); + out_indices.emplace_back(std::pair(i, feature_index)); + } + } + + Tensor* indices_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 1, + TensorShape({static_cast(out_indices.size()), + input_tensor->shape().dims() + 1}), + &indices_tensor)); + auto indices = indices_tensor->matrix(); + // Translate flat index to shaped index like np.unravel_index + // Calculate factors for each dimension + std::vector factors(input_tensor->shape().dims()); + factors[input_tensor->shape().dims() - 1] = 1; + for (int j = input_tensor->shape().dims() - 2; j >= 0; j--) { + factors[j] = factors[j + 1] * input_tensor->shape().dim_size(j + 1); + } + for (int i = 0; i < out_indices.size(); i++) { + indices(i, 0) = out_indices[i].first; + int64 value = out_indices[i].first; + for (int j = 0; j < input_tensor->shape().dims(); j++) { + indices(i, j) = value / factors[j]; + value = value % factors[j]; + } + indices(i, input_tensor->shape().dims()) = out_indices[i].second; + } + + Tensor* values_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_output( + 2, TensorShape({static_cast(out_values.size())}), + &values_tensor)); + auto values = values_tensor->vec(); + std::copy_n(out_values.begin(), out_values.size(), &values(0)); + + Tensor* shape_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 3, TensorShape({input_tensor->shape().dims() + 1}), + &shape_tensor)); + auto shape = shape_tensor->flat(); + for (int i = 0; i < input_tensor->shape().dims(); i++) { + shape(i) = input_tensor->shape().dim_size(i); + } + shape(input_tensor->shape().dims()) = num_features_; + } + + private: + int64 num_features_; +}; + +namespace { +template <> +bool ConvertHelper(const string& s, float* value) { + return strings::safe_strtof(s.c_str(), value); +} +template <> +bool ConvertHelper(const string& s, double* value) { + return strings::safe_strtod(s.c_str(), value); +} +template <> +bool ConvertHelper(const string& s, int32* value) { + return strings::safe_strto32(s.c_str(), value); +} +template <> +bool ConvertHelper(const string& s, int64* value) { + return strings::safe_strto64(s.c_str(), value); +} +} // namespace + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(int64); +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/contrib/libsvm/ops/libsvm_ops.cc b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..4c65e676291d1244f5224c43d32a321ae72ffe41 --- /dev/null +++ b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("DecodeLibsvm") + .Input("input: string") + .Output("label: label_dtype") + .Output("feature_indices: int64") + .Output("feature_values: dtype") + .Output("feature_shape: int64") + .Attr("dtype: {float, double, int32, int64} = DT_FLOAT") + .Attr("label_dtype: {float, double, int32, int64} = DT_INT64") + .Attr("num_features: int >= 1") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + + c->set_output(1, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(3, c->Vector(InferenceContext::kUnknownDim)); + + return Status::OK(); + }) + + .Doc(R"doc( +Convert LibSVM input to tensors. The output consists of +a label and a feature tensor. The shape of the label tensor +is the same as input and the shape of the feature tensor is +`[input_shape, num_features]`. + +input: Each string is a record in the LibSVM. +label: A tensor of the same shape as input. +feature_indices: A 2-D int64 tensor of dense_shape [N, ndims]. +feature_values: A 1-D tensor of any type and dense_shape [N]. +feature_shape: A 1-D int64 tensor of dense_shape [ndims]. +num_features: The number of features. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9d5ceed393c03e6baa0872950670cf1ff71d3f --- /dev/null +++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py @@ -0,0 +1,72 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DecodeLibsvm op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.libsvm.python.ops import libsvm_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class DecodeLibsvmOpTest(test.TestCase): + + def testBasic(self): + with self.test_session() as sess: + content = ["1 1:3.4 2:0.5 4:0.231", + "1 2:2.5 3:inf 5:0.503", + "2 3:2.5 2:nan 1:0.105"] + sparse_features, labels = libsvm_ops.decode_libsvm(content, + num_features=6) + features = sparse_ops.sparse_tensor_to_dense(sparse_features, + validate_indices=False) + + self.assertAllEqual(labels.get_shape().as_list(), [3]) + + features, labels = sess.run([features, labels]) + self.assertAllEqual(labels, [1, 1, 2]) + self.assertAllClose(features, [[0, 3.4, 0.5, 0, 0.231, 0], + [0, 0, 2.5, np.inf, 0, 0.503], + [0, 0.105, np.nan, 2.5, 0, 0]]) + + def testNDimension(self): + with self.test_session() as sess: + content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"], + ["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"], + ["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]] + sparse_features, labels = libsvm_ops.decode_libsvm( + content, num_features=6, label_dtype=dtypes.float64) + features = sparse_ops.sparse_tensor_to_dense(sparse_features, + validate_indices=False) + + self.assertAllEqual(labels.get_shape().as_list(), [3, 2]) + + features, labels = sess.run([features, labels]) + self.assertAllEqual(labels, [[1, 1], [1, 1], [2, 2]]) + self.assertAllClose(features, [[[0, 3.4, 0.5, 0, 0.231, 0], + [0, 3.4, 0.5, 0, 0.231, 0]], + [[0, 0, 2.5, np.inf, 0, 0.503], + [0, 0, 2.5, np.inf, 0, 0.503]], + [[0, 0.105, np.nan, 2.5, 0, 0], + [0, 0.105, np.nan, 2.5, 0, 0]]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9c133e7e7f048966f222a5e3a1a61c5ac7c723eb --- /dev/null +++ b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py @@ -0,0 +1,51 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Libsvm decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.libsvm.ops import gen_libsvm_ops +from tensorflow.contrib.util import loader +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import io_ops +from tensorflow.python.platform import resource_loader + + +_libsvm_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_libsvm_ops.so")) + +def decode_libsvm(content, num_features, dtype=None, label_dtype=None): + """Convert Libsvm records to a tensor of label and a tensor of feature. + + Args: + content: A `Tensor` of type `string`. Each string is a record/row in + the Libsvm format. + num_features: The number of features. + dtype: The type of the output feature tensor. Default to tf.float32. + label_dtype: The type of the output label tensor. Default to tf.int64. + + Returns: + features: A `SparseTensor` of the shape `[input_shape, num_features]`. + labels: A `Tensor` of the same shape as content. + """ + labels, indices, values, shape = gen_libsvm_ops.decode_libsvm( + content, num_features, dtype=dtype, label_dtype=label_dtype) + return sparse_tensor.SparseTensor(indices, values, shape), labels + + +ops.NotDifferentiable('DecodeLibSVM') diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 86d848439191aeb0dfa88bbe0fb9b3b654499423..7526f3ae0dbdb3d6827e9d7f690090b8438e4f6e 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -251,8 +251,9 @@ class SdcaModel(object): result_dense = 0.0 for i in range(len(dense_variables)): - result_dense += math_ops.matmul( - dense_features[i], array_ops.expand_dims(dense_variables[i], -1)) + result_dense += math_ops.matmul(dense_features[i], + array_ops.expand_dims( + dense_variables[i], -1)) # Reshaping to allow shape inference at graph construction time. return array_ops.reshape(result_dense, [-1]) + result_sparse diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 701fc1c0597d1de0b0189e86feafbd1c5bbdc818..05794a42c5f2d0eece6adab36fb5610078cece31 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -154,7 +154,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step( columns_to_variables, weight_column_name, loss_type, features, labels, global_step) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 96a9e281ad11009e8406bb6ccd583adba09f9f0d..3f1b0be1a73a3ff1da3452f4ee1a9125f9e26178 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -111,6 +111,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -134,6 +135,7 @@ cc_test( srcs = ["simple_memory_arena_test.cc"], deps = [ ":framework", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -152,6 +154,7 @@ cc_test( ], deps = [ ":framework", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -163,6 +166,7 @@ cc_test( srcs = ["context_test.cc"], deps = [ ":framework", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -191,6 +195,9 @@ filegroup( exclude = [ "**/METADATA", "**/OWNERS", + "downloads", + "examples", + "gen", ], ), visibility = ["//tensorflow:__subpackages__"], diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index 827c5d0baa90b73a72c3565e23c417c24b1d06d8..852284cbc7f33b5d9c0f7774bca89c1dff3fa3ec 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -1,10 +1,10 @@ # TensorFlow Lite -TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. +TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. ![image](g3doc/TFLite-Architecture.jpg) -# Getting Started with a Demo App +# Getting Started with an Android Demo App This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. @@ -17,21 +17,21 @@ There are 3 ways to get the demo app to your device In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. ## Downloading the pre-built binary -The fastest path to trying the demo, is to download the pre-built binary +The fastest path to trying the demo, is to download the pre-built binary [TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) -Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera’s field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. +Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. ## Building in Android Studio using TensorFlow Lite AAR from JCenter The simplest way to compile the demo app, and try out changes to the project code is to use AndroidStudio. - Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html). - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). - - Import the tensorflow/contrib/lite/java/demo directory as a new Android Studio project. + - Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project. - Click through installing all the Gradle extensions it requests. - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: - tensorflow/contrib/lite/java/demo/app/src/main/assets/ + `tensorflow/contrib/lite/java/demo/app/src/main/assets/` - Build and run the demo app ## Building TensorFlow Lite and the demo app from source @@ -43,7 +43,7 @@ The simplest way to compile the demo app, and try out changes to the project cod ### Install Bazel If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html) -NOTE: Bazel does not currently support building for Android on Windows. Full support for gradle/cmake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. +NOTE: Bazel does not fully support building Android on Windows yet. Full support for Gradle/CMake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. ### Install Android NDK and SDK Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system. @@ -53,25 +53,30 @@ Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.` ``` - Android_sdk_repository ( - name = "androidsdk", - api_level = 23, - build_tools_version = "23.0.2", - path = "/home/xxxx/android-sdk-linux/", ) +android_sdk_repository ( + name = "androidsdk", + api_level = 23, + build_tools_version = "23.0.2", + path = "/home/xxxx/android-sdk-linux/", +) android_ndk_repository( - name="androidndk", - path="/home/xxxx/android-ndk-r10e/", - api_level=19) - + name = "androidndk", + path = "/home/xxxx/android-ndk-r10e/", + api_level = 19, +) ``` + Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). -### Build the source code +### Build the source code Run bazel with the following command to build the demo. Build the demo app: -bazel build --cxxopt='--std=c++11' //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo + +``` +bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo +``` ### Note @@ -81,6 +86,17 @@ environment (due to a Bazel bug). ### More about the demo The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app. +# iOS Demo App + +Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). + +This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: + +1. Follow the Building section [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md#building) to build the universal iOS library for TensorFlow Lite. +1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. +1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. +1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. + # TensorFlow Lite Quick Start ## Step 1. Decide which GraphDef to use @@ -105,7 +121,7 @@ The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tenso ### Train a custom model -A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow’s Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. +A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This set will continue to expand in future releases of Tensorflow Lite. @@ -129,7 +145,7 @@ Since we employ several formats, the following definitions may be useful: - TensorFlow lite model (.lite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. ### Freeze Graph -To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as “freezing” the graph. +To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph. The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)). @@ -151,12 +167,13 @@ graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/te This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. Here is a sample command line to convert the frozen Graphdef to '.lite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. +(Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). ``` bazel build tensorflow/contrib/lite/toco:toco -bazel run --config=opt tensorflow/contrib/lite/toco:toco -- \ - --input_file=(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ +bazel-bin/tensorflow/contrib/lite/toco/toco -- \ + --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ --output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \ --input_type=FLOAT --input_arrays=input \ @@ -169,7 +186,7 @@ bazel run --config=opt tensorflow/contrib/lite/toco:toco -- \ - Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the -documentation [here](https://github.com/tensorflow/tensorflow/tree/mastertensorflow/contrib/lite/python:toco_from_protos target) A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, +documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/python/toco_from_protos.py). A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, ``` import tensorflow as tf @@ -184,7 +201,7 @@ with tf.Session() as sess: ``` For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md). -You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn’t help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). +You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). ## Step 3. Use the TensorFlow Lite model for inference in a mobile app @@ -193,9 +210,13 @@ After completion of Step 2 the developer should have a .lite model. ### For Android Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). -The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it’s a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). +The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it's a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). -Note that you’d need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). +Note that you'd need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). ### For iOS -Follow the documentation [here](https://github.com/TensorFlow/TensorFlow/blob/master/TensorFlow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. +Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. + +## Core ML support + +Core ML is a machine learning framework used across Apple products. In addition to using Tensorflow Lite models directly in their applications, developers have the option to convert their trained Tensorflow models to the [CoreML](https://developer.apple.com/machine-learning/) format for use on Apple devices. For information on how to use the converter please refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index e3c9cdd99beb93e356c148298dcbe6498fbe0306..d1fcdce70a34393defce0f2d0f6d5bb53f21c45e 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -89,6 +89,7 @@ def tflite_jni_linkopts(): return tflite_jni_linkopts_unstripped() + select({ "//tensorflow:android": [ "-s", # Omit symbol table. + "-latomic", # Required for some uses of ISO C++11 in x86. ], "//conditions:default": [], }) @@ -223,11 +224,12 @@ def gen_selected_ops(name, model): """ out = name + "_registration.cc" tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/contrib/lite" native.genrule( name = name, srcs = [model], outs = [out], - cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)") - % (tool, model, out), + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") + % (tool, model, out, tflite_path[2:]), tools = [tool], ) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index e0f2ef768bfed544ed8acd6c0e3a5823e61a1e8c..cbc96e6edd4358f6666731caa4c208c77d9c6c54 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -1,4 +1,19 @@ #!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + set -e make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 93072bf90bd8a18d9011a74c2eec95d86dbdce8a..5c6f3016b1c7d06ba35229faeff9cec32e168ef2 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -104,6 +104,17 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteAddParams; +typedef struct { + // Number of spatial dimensions. + // For now only NHWC is supported, and the value should always be 2. + int num_spatial_dimensions; + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int block_shape[2]; + int before_crops[2]; + int after_crops[2]; +} TfLiteBatchToSpaceNDParams; + typedef struct { TfLiteFusedActivation activation; } TfLiteMulParams; @@ -130,6 +141,14 @@ typedef struct { int new_width; } TfLiteResizeBilinearParams; +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int before_padding[8]; + int after_padding[8]; + int num_dimensions; +} TfLitePadParams; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. @@ -157,6 +176,10 @@ typedef struct { TfLiteCombinerType combiner; } TfLiteEmbeddingLookupSparseParams; +typedef struct { + int axis; +} TfLiteGatherParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc index d0a104f43d9b9d148d80ce26b8ecf732d51ef110..20d6f69a25e9f0bb4323cf5d067b8ebd37bb3c23 100644 --- a/tensorflow/contrib/lite/context_test.cc +++ b/tensorflow/contrib/lite/context_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include +#include "tensorflow/contrib/lite/testing/util.h" namespace tflite { @@ -68,7 +69,7 @@ TEST(IntArray, TestIntArrayEqual) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 41480c20077f4b31928cf17ff02e357f5dea6851..7fce1ba3461066e6dada95246781440258d844c1 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,13 @@ set -e DOWNLOADS_DIR=tensorflow/contrib/lite/downloads BZL_FILE_PATH=tensorflow/workspace.bzl +# Ensure it is being run from repo root +if [ ! -f $BZL_FILE_PATH ]; then + echo "Could not find ${BZL_FILE_PATH}": + echo "Likely you are not running this from the root directory of the repository."; + exit 1; +fi + EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" @@ -56,11 +63,19 @@ download_and_extract() { elif [[ "${url}" == *zip ]]; then tempdir=$(mktemp -d) tempdir2=$(mktemp -d) - wget -P ${tempdir} ${url} - unzip ${tempdir}/* -d ${tempdir2} - # unzip has no strip components, so unzip to a temp dir, and move the files - # we want from the tempdir to destination. - echo cp `find ${tempdir2} -type f` ${dir}/ + + curl -L ${url} > ${tempdir}/zipped.zip + unzip ${tempdir}/zipped.zip -d ${tempdir2} + + # If the zip file contains nested directories, extract the files from the + # inner directory. + if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then + # unzip has no strip components, so unzip to a temp dir, and move the + # files we want from the tempdir to destination. + cp -R ${tempdir2}/*/* ${dir}/ + else + cp -R ${tempdir2}/* ${dir}/ + fi rm -rf ${tempdir2} ${tempdir} fi diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc index 6ba5384a94dbf9de03fb2e4e2f63074525eafa2d..03fcd5409ceab1895cea3b9e0e4fcb5a127e6a45 100644 --- a/tensorflow/contrib/lite/error_reporter.cc +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -39,7 +39,9 @@ int ErrorReporter::ReportError(void*, const char* format, ...) { } int StderrReporter::Report(const char* format, va_list args) { - return vfprintf(stderr, format, args); + const int result = vfprintf(stderr, format, args); + fputc('\n', stderr); + return result; } ErrorReporter* DefaultErrorReporter() { diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index 637d456ce7a754c7da34e551869e49b4efd18e3b..d5715e4f90aead79a617fe4576bfe5100d5e121a 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -25,10 +25,10 @@ namespace tflite { // // Usage: // ErrorReporter foo; -// foo.Report("test %d\n", 5); +// foo.Report("test %d", 5); // or // va_list args; -// foo.Report("test %d\n", args); // where args is va_list +// foo.Report("test %d", args); // where args is va_list // // Sublclass ErrorReporter to provide another reporting destination. // For example, if you have a GUI program, you might redirect to a buffer diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index ea398ad14e8be4c5a0021befc7cc076549b47e23..10f31bb6f17242c9f7f70f0648ec643f99c5ac86 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -123,7 +123,11 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const AVCaptureDevice* device = [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; AVCaptureDeviceInput* deviceInput = [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; - assert(error == nil); + + if (error != nil) { + NSLog(@"Failed to initialize AVCaptureDeviceInput. Note: This app doesn't work with simulator"); + assert(NO); + } if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; diff --git a/tensorflow/contrib/lite/examples/ios/camera/data/.gitignore b/tensorflow/contrib/lite/examples/ios/camera/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h index 75b1f1da384b527e8332dfba08fec87c65eff8b1..94046d9728258901091f018fd0d081651145f400 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h @@ -14,8 +14,8 @@ #import -@interface AppDelegate : UIResponder +@interface AppDelegate : UIResponder -@property (strong, nonatomic) UIWindow *window; +@property(strong, nonatomic) UIWindow *window; @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm index 1e808eb976ff3eeda4cf6f81b3c1794c6a037dc8..d1215fa0bffd978b4aaadbd8bc13b07723703c9a 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm @@ -22,8 +22,7 @@ didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { UITabBarController *bar = [[UITabBarController alloc] init]; - [bar setViewControllers: - @[[[RunModelViewController alloc] init]]]; + [bar setViewControllers:@[ [[RunModelViewController alloc] init] ]]; bar.selectedIndex = 0; self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; self.window.rootViewController = bar; @@ -31,14 +30,19 @@ return YES; } -- (void)applicationWillResignActive:(UIApplication *)application {} +- (void)applicationWillResignActive:(UIApplication *)application { +} -- (void)applicationDidEnterBackground:(UIApplication *)application {} +- (void)applicationDidEnterBackground:(UIApplication *)application { +} -- (void)applicationWillEnterForeground:(UIApplication *)application {} +- (void)applicationWillEnterForeground:(UIApplication *)application { +} -- (void)applicationDidBecomeActive:(UIApplication *)application {} +- (void)applicationDidBecomeActive:(UIApplication *)application { +} -- (void)applicationWillTerminate:(UIApplication *)application {} +- (void)applicationWillTerminate:(UIApplication *)application { +} @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h index 4e1a83ccf5a12c609baadab7359c55ec4f464ed8..a4b358b4eb7f6ba109638405091b798d30bd1768 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h @@ -18,7 +18,7 @@ - (IBAction)getUrl:(id)sender; -@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView; -@property (weak, nonatomic) IBOutlet UITextField *urlTextField; +@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property(weak, nonatomic) IBOutlet UITextField *urlTextField; @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm index 965d83010516c6db72c9e8b1c33079b3eda204de..0dafb1f61e19f46bb3b17f07c55e09f5813ed560 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -14,10 +14,10 @@ #import "RunModelViewController.h" -#include -#include #include #include +#include +#include #include #include #include @@ -30,7 +30,11 @@ #include "ios_image_load.h" #define LOG(x) std::cerr -#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); } +#define CHECK(x) \ + if (!(x)) { \ + LOG(ERROR) << #x << "failed"; \ + exit(1); \ + } NSString* RunInferenceOnImage(); @@ -49,15 +53,12 @@ NSString* RunInferenceOnImage(); // Returns the top N confidence values over threshold in the provided vector, // sorted by confidence in descending order. -static void GetTopN( - const float* prediction, - const int prediction_size, - const int num_results, const float threshold, - std::vector >* top_results) { +static void GetTopN(const float* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector >* top_results) { // Will contain top N results in ascending order. - std::priority_queue, - std::vector >, - std::greater > > top_result_pq; + std::priority_queue, std::vector >, + std::greater > > + top_result_pq; const long count = prediction_size; for (int i = 0; i < count; ++i) { @@ -88,8 +89,8 @@ static void GetTopN( NSString* FilePathForResourceName(NSString* name, NSString* extension) { NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; } return file_path; } @@ -102,7 +103,8 @@ NSString* RunInferenceOnImage() { NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite"); - std::unique_ptr model(tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); + std::unique_ptr model( + tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); if (!model) { LOG(FATAL) << "Failed to mmap model " << graph; } @@ -143,7 +145,7 @@ NSString* RunInferenceOnImage() { std::ifstream t; t.open([labels_path UTF8String]); std::string line; - while(t){ + while (t) { std::getline(t, line); label_strings.push_back(line); } @@ -154,7 +156,8 @@ NSString* RunInferenceOnImage() { int image_width; int image_height; int image_channels; - std::vector image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); + std::vector image_data = + LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); const int wanted_width = 224; const int wanted_height = 224; const int wanted_channels = 3; @@ -212,8 +215,7 @@ NSString* RunInferenceOnImage() { std::string predictions = ss.str(); NSString* result = @""; - result = [NSString stringWithFormat: @"%@ - %s", result, - predictions.c_str()]; - + result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; + return result; } diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h index 7287d0d63d5b4c0b9c9a528578b6341cdb9c9954..98934ce41d349b33d4fc010a39a956e52f3d5721 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h @@ -17,9 +17,7 @@ #include -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); +std::vector LoadImageFromFile(const char* file_name, int* out_width, + int* out_height, int* out_channels); #endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm index 789522d2a9900b136f91f77c4ada682f1a316848..cb0fe1a7650c572d3745066431f2759daa94ffc9 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm @@ -14,17 +14,16 @@ #include "ios_image_load.h" -#include -#include #include #include +#include +#include #import #import -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { +std::vector LoadImageFromFile(const char* file_name, int* out_width, int* out_height, + int* out_channels) { FILE* file_handle = fopen(file_name, "rb"); fseek(file_handle, 0, SEEK_END); const size_t bytes_in_file = ftell(file_handle); @@ -32,11 +31,10 @@ std::vector LoadImageFromFile(const char* file_name, std::vector file_data(bytes_in_file); fread(file_data.data(), 1, bytes_in_file, file_handle); fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); + + CFDataRef file_data_ref = + CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull); + CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref); const char* suffix = strrchr(file_name, '.'); if (!suffix || suffix == file_name) { @@ -44,12 +42,10 @@ std::vector LoadImageFromFile(const char* file_name, } CGImageRef image; if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || (strcasecmp(suffix, ".jpeg") == 0)) { + image = + CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); } else { CFRelease(image_provider); CFRelease(file_data_ref); @@ -68,9 +64,10 @@ std::vector LoadImageFromFile(const char* file_name, const int bytes_in_image = (bytes_per_row * height); std::vector result(bytes_in_image); const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + + CGContextRef context = + CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row, + color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); CGColorSpaceRelease(color_space); CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); CGContextRelease(context); diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/contrib/lite/examples/ios/simple/main.mm index d70550a730720e5d6799a186c1beb3cfa04b0b9d..05cb55ddd7a230593863e64b351f6aac31a1b4d7 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/main.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/main.mm @@ -14,7 +14,7 @@ #import -int main(int argc, char * argv[]) { +int main(int argc, char *argv[]) { @autoreleasepool { NSString *delegateClassName = @"AppDelegate"; return UIApplicationMain(argc, argv, nil, delegateClassName); diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index 662ae2032c990b649fc6d34dcf915d58796c0665..fe208e47d1ac10995881e55c8596ae14ff4242df 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -52,7 +52,7 @@ typedef enum { Failures can be easily verified with: ```c++ if (status != kTfLiteOk) { - // ... error handling here ... + // ... error handling here ... } ``` diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index ce8b37fbf9b0db5dee60784e85a3cbf0326fddb6..a359b8d4b481dbc15cc86db14eabda5433722b8b 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -45,6 +45,10 @@ into a universal file containing armv7, armv7s, arm64, i386, and x86_64 architectures. The resulting library is in `tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a`. +If you get an error such as `no such file or directory: 'x86_64'` when running +`build_ios_universal_lib.sh`: open Xcode > Preferences > Locations, and ensure +a value is selected in the "Command Line Tools" dropdown. + ## Using in your own application You'll need to update various settings in your app to link against TensorFlow diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 8bf60e91f769338aa0751761c2dc0df417ee0943..65c61e44bee48535f884a3afaddc691972f5e04b 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/simple_memory_arena.h" -#include "tensorflow/core/platform/platform.h" namespace tflite { @@ -232,7 +231,6 @@ class Interpreter { // If you know that your sizes are not changing, you need not call this. // Returns status of success or failure. - // TODO(aselle): Madde TfLiteStatus AllocateTensors(); // Invoke the interpreter (run the whole graph in dependency order). diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 71b633c5774d93684f651821adad13c378a8243c..2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -8,7 +8,12 @@ It's easiest with Android Studio. - You'll need at least SDK version 23. + - Make sure to install the latest version of Bazel. Some distributions + ship with Bazel 0.5.4, which is too old. - Bazel requires Android Build Tools `26.0.1` or higher. + - **Bazel is incompatible with NDK revisions 15 and above,** with revision + 16 being a compile-breaking change. [Download an older version manually + instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites) - You also need to install the Android Support Repository, available through Android Studio under `Android SDK Manager -> SDK Tools -> Android Support Repository`. @@ -16,10 +21,15 @@ 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace) to add SDK and NDK targets. + NOTE: As long as you have the SDK and NDK installed, the `./configure` + script will create these rules for you. Answer "Yes" when the script asks + to automatically configure the `./WORKSPACE`. + - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that you have installed. - By default, Android Studio will install the SDK to `~/Android/Sdk` and - the NDK to `~/Android/Sdk/ndk-bundle`. + the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual + download until Bazel supports NDK 16. See bullet points under (1)). 2. Build the app with Bazel. The demo needs C++11: diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml index ab7d3fd496376ae702ca75a8c496863b1ff93a90..0a71dbd0e8010f5e3a176de1f7e8321331289f7c 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -19,12 +19,12 @@ TfLiteCameraDemo diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index bbbfa3e7415bfd7a34dfc7d764da55cac22e7d42..cc02cddb3d6cce3787fd15ee1734a490389fb9b3 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:schema_fbs_version", "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/testing:util", "//tensorflow/core:lib", "@com_google_googletest//:gtest", ], @@ -76,12 +77,14 @@ cc_library( "activations.cc", "add.cc", "basic_rnn.cc", + "batch_to_space_nd.cc", "concatenation.cc", "conv.cc", "depthwise_conv.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", "fully_connected.cc", + "gather.cc", "hashtable_lookup.cc", "kernel_util.cc", "l2norm.cc", @@ -89,6 +92,7 @@ cc_library( "lsh_projection.cc", "lstm.cc", "mul.cc", + "pad.cc", "pooling.cc", "register.cc", "reshape.cc", @@ -96,6 +100,7 @@ cc_library( "skip_gram.cc", "space_to_depth.cc", "svdf.cc", + "unidirectional_sequence_rnn.cc", ], hdrs = [ "kernel_util.h", @@ -152,6 +157,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "batch_to_space_nd_test", + size = "small", + srcs = ["batch_to_space_nd_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "concatenation_test", size = "small", @@ -200,6 +217,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "unidirectional_sequence_rnn_test", + size = "small", + srcs = ["unidirectional_sequence_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "l2norm_test", size = "small", @@ -224,6 +253,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "pad_test", + size = "small", + srcs = ["pad_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "reshape_test", size = "small", @@ -236,6 +277,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gather_test", + size = "small", + srcs = ["gather_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "resize_bilinear_test", size = "small", diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index f10aee70170d4a94ed54376fa410b22a60f109af..33ca56e745c043efd12b851af14f273fb273d577 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -317,7 +317,7 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 8e12a837c4954832ff37a6d1ab377bee9e8d5763..ddf45bb576755d57d50c9e6e01bf50f15612c56d 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -164,8 +164,7 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { } // namespace } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc index dfa75655bcfe7762c6cc4c9a98a71d529028c03a..5ecccb985e91238f1183c8f94a2b5f468758ce55 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -261,7 +261,7 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc new file mode 100644 index 0000000000000000000000000000000000000000..0eed680fdcc2afc4bc72be55a5e7722310fa4538 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace batch_to_space_nd { + +// This file has two implementations of BatchToSpaceND. +enum KernelType { + kReference, + kGenericOptimized, +}; + +struct BatchToSpaceNDContext { + BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteBatchToSpaceNDParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +// Currently, only 4D NHWC input/output op_context are supported. +// The 4D array need to have exactly 2 spatial dimensions. +// TODO(ycling): Support arbitrary dimension in BatchToSpaceND. +const int kInputDimensionNum = 4; +const int kOutputDimensionNum = 4; +const int kSpatialDimensionNum = 2; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now. + TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + BatchToSpaceNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + const TfLiteIntArray* input_size = op_context.input->dims; + const int* block_shape = op_context.params->block_shape; + + // Number of batch must be multiple of (block_shape[0] * block_shape[1]). + TF_LITE_ENSURE_EQ(context, + input_size->data[0] % (block_shape[0] * block_shape[1]), 0); + + const int output_batch_size = + input_size->data[0] / (block_shape[0] * block_shape[1]); + const int output_height = input_size->data[1] * block_shape[0]; + const int output_width = input_size->data[2] * block_shape[1]; + const int output_channel_size = input_size->data[3]; + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + output_size->data[0] = output_batch_size; + output_size->data[1] = output_height; + output_size->data[2] = output_width; + output_size->data[3] = output_channel_size; + + return context->ResizeTensor(context, op_context.output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + BatchToSpaceNDContext op_context(context, node); + + int block_shape_dims_array[1] = {kSpatialDimensionNum}; + Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); + +#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ + type::BatchToSpaceND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + op_context.params->block_shape, block_shape_dims, \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + switch (op_context.input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported by BatchToSpace."); + return kTfLiteError; + } +#undef TF_LITE_BATCH_TO_SPACE_ND + return kTfLiteOk; +} + +} // namespace batch_to_space_nd + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_to_space_nd::Prepare, + batch_to_space_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_to_space_nd::Prepare, + batch_to_space_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND() { + return Register_BATCH_TO_SPACE_ND_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ec4efbebcef9d55d0042d93007018c9f6ee3b58 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BatchToSpaceNDOpModel : public SingleOpModel { + public: + BatchToSpaceNDOpModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list before_crops, + std::initializer_list after_crops) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions( + builder_, builder_.CreateVector(block_shape), + builder_.CreateVector(before_crops), + builder_.CreateVector(after_crops)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(BatchToSpaceNDOpTest, SimpleTest) { + BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + +TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { + EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}), + "Cannot allocate tensors"); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc index 94e5b2acdcabeedb4652baa1a008b22bf6bc8433..499856a93cbbfbf9aa1a326912e52ce32bbbdf83 100644 --- a/tensorflow/contrib/lite/kernels/concatenation_test.cc +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -156,7 +156,7 @@ TEST(ConcatenationOpTest, FourInputsQuantized) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 18d7a31d594efb6a05fe7292a0194ea17599a65b..1d0a81c3135625c07a3566f5f9a8e5401f0d4db7 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -434,7 +434,7 @@ TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc index 39227b2811e2be719a0be77f89793bcf9366d513..1439c8bce14ad127ed68dc54991aed8b8bb39383 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -180,7 +180,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc index 69d9c5cc7dec13a65f1c5050f2f1c56812ad5aa1..dcdc5fffad9ceac1a9d23a4e91637a9ff92a8dda 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc @@ -158,9 +158,7 @@ TEST(EmbeddingLookupOpTest, Indices3DTest) { } // namespace tflite int main(int argc, char** argv) { -#ifdef OS_LINUX - tflite::LogToStderr(); -#endif + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 8c030b06772ac0c6af34a45897f03ebc4637d4de..9b501878f196216a61568bfa36e6615f4dd07478 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -88,7 +88,7 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc index 112e3f1ba01a428023eea5ee8410fb76c1d67de6..a0f766c4f4580d7679275c0b63aa200410fcb5ad 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -370,8 +370,7 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8df797daf7338e33b16508c21fc61cd9836db1e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -0,0 +1,130 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace gather { +constexpr int kInputTensor = 0; +constexpr int kInputPositions = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* positions = GetInput(context, node, kInputPositions); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // Only INT32 positions are supported. + TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); + // Check that input and output types match. + TF_LITE_ENSURE_EQ(context, input->type, output->type); + // TODO(mgubin): only 1D positions are currently supported. + TF_LITE_ENSURE_EQ(context, NumDimensions(positions), 1); + // TODO(mgubin): Only default axis == 0 is supported. + // Check conditions for different types. + switch (input->type) { + case kTfLiteFloat32: + case kTfLiteUInt8: + case kTfLiteInt32: { + // Fully supported by reference_ops::Gather. + } break; + + case kTfLiteString: { + // Only 1D input is supported. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + } break; + default: + context->ReportError(context, + "Only float32 and string types are supported"); + return kTfLiteError; + } + const int num_dimensions = + NumDimensions(input) + NumDimensions(positions) - 1; + TF_LITE_ENSURE(context, params->axis < num_dimensions); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + int output_index = 0; + for (int i = 0; i < params->axis; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + for (int i = 0; i < positions->dims->size; ++i) { + output_shape->data[output_index++] = positions->dims->data[i]; + } + for (int i = params->axis + 1; i < input->dims->size; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* positions = GetInput(context, node, kInputPositions); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const int input_rank = NumDimensions(input); +#define TF_LITE_GATHER(data_type, index_type) \ + optimized_ops::Gather( \ + GetTensorData(input), GetTensorDims(input), input_rank, \ + GetTensorData(positions), GetTensorDims(positions), \ + GetTensorData(output), GetTensorDims(output)); + switch (input->type) { + case kTfLiteFloat32: + TF_LITE_GATHER(float, int32_t); + break; + case kTfLiteUInt8: + TF_LITE_GATHER(uint8_t, int32_t); + break; + case kTfLiteInt32: + TF_LITE_GATHER(int32_t, int32_t); + break; + case kTfLiteString: { + DynamicBuffer buffer; + const int32* indexes = positions->data.i32; + const int num_strings = GetStringCount(input); + for (int i = 0; i < positions->dims->data[0]; ++i) { + const int pos = indexes[i]; + TF_LITE_ENSURE(context, pos < num_strings); + const auto string_ref = GetString(input, pos); + buffer.AddString(string_ref.str, string_ref.len); + } + buffer.WriteToTensor(output); + } break; + default: + return kTfLiteError; + } +#undef TF_LITE_GATHER + return kTfLiteOk; +} +} // namespace gather + +TfLiteRegistration* Register_GATHER() { + static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare, + gather::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6343d3b4ef20ae3e030396ec1b6adbcf83a3e45f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class GatherOpModel : public SingleOpModel { + public: + GatherOpModel(std::initializer_list input_shape, TensorType input_type, + std::initializer_list positions_shape) { + input_ = AddInput(input_type); + positions_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_GATHER, BuiltinOptions_GatherOptions, + CreateGatherOptions(builder_, 0).Union()); + BuildInterpreter({input_shape, positions_shape}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUint8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(std::initializer_list data) { + PopulateStringTensor(input_, data); + } + + void SetPositions(std::initializer_list data) { + PopulateTensor(positions_, data); + } + + std::vector GetOutputFloat() { return ExtractVector(output_); } + std::vector GetOutputUint8() { + return ExtractVector(output_); + } + std::vector GetOutputString() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int positions_; + int output_; +}; + +TEST(GatherOpTest, Shuffle) { + GatherOpModel m({2, 2}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2}))); +} + +TEST(FloatGatherOpTest, Duplicate) { + GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({0, 0}); + m.Invoke(); + EXPECT_THAT( + m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({-2, 0.2, 0.7, 0.8, -2, 0.2, 0.7, 0.8}))); +} + +TEST(FloatGatherOpTest, Slice) { + GatherOpModel m({4, 1}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({0.2, 0.8}))); +} + +TEST(Uint8tGatherOpTest, Shuffle) { + GatherOpModel m({2, 2}, TensorType_UINT8, {2}); + m.SetInputUint8({133, 134, 14, 15}); + m.SetPositions({1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutputUint8(), ElementsAreArray({14, 15, 133, 134})); +} + +TEST(GatherOpTest, SimpleString) { + GatherOpModel m({3}, TensorType_STRING, {2}); + m.SetInput({"A", "B", "C"}); + m.SetPositions({0, 2}); + m.Invoke(); + ASSERT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutputString(), ElementsAreArray({"A", "C"})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc index 916a23225e2ad3c5645a7809169677a7a8880535..cb6038f9009a3865661e7b4f075c3033166d0f91 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -170,7 +170,7 @@ TEST(HashtableLookupOpTest, TestString) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 288534099b9e090ce0c223a401b4152ca6ffb61f..a3ecb2ebf6a889729954d1e447997c510e8ff6d4 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -124,6 +124,13 @@ config_setting( }, ) +config_setting( + name = "freebsd", + values = { + "cpu": "freebsd", + }, +) + cc_library( name = "optimized_base", srcs = [], @@ -147,6 +154,7 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":freebsd": tflite_deps_intel, "//conditions:default": [], }), ) @@ -224,6 +232,7 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":freebsd": tflite_deps_intel, "//conditions:default": [], }), ) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h index 974611f52ac74cec275f978c5af5bd561688db78..da34c8aef94b1c69e661bd33fcb518e73034c4bd 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -311,6 +311,9 @@ struct FloatDepthwiseConvKernel { } }; +// Note this implementation is very slow for input_depths < 8 +// (e.g. comparable to reference implementation) see, specializations for +// input_depth=3 below. template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -417,6 +420,74 @@ struct FloatDepthwiseConvKernel { } }; +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x2_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1_f32(filter_ptr + 2 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x2_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate for each input channel there 2 outputs + acc[0] = vmla_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 6; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // NOTE: we only want 3 values, so we read it as two ops where + // the second op just duplicates the lane + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x4_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate all outputs. + acc[0] = vmlaq_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 12; + input_ptr += input_ptr_increment; + } + } +}; + template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -857,6 +928,8 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 4) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) // Finally, the kernels allowing a variable input depth, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index cd565c16a1ee7226f83c19f0020beed75e401497..2df919e579efaaa283f191df91cd433374b31567 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -3704,6 +3704,43 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, auto max_value = input2_data[0]; output_map.array() = input1_map.array().max(max_value); } + +template +void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ArgMax"); + + // The current ArgMax implemention can only determine the index of the maximum + // value in the last dimension. So the axis argument is ignored. + TFLITE_DCHECK_EQ(axis[0], 3); + + // For ArgMax, the number of output dimensions = (number of input dimensions - + // 1). For the sake of simplicity, the output dimensions are equal to the + // input dimensions here. We enforce the constraint that the last dimension + // must always be 1. + TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = ArraySize(input_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + auto max_value = input_data[Offset(input_dims, 0, x, y, b)]; + int max_index = 0; + for (int d = 1; d < depth; ++d) { + const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)]; + if (curr_value > max_value) { + max_value = curr_value; + max_index = d; + } + } + output_data[Offset(output_dims, 0, x, y, b)] = max_index; + } + } + } +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index c2ab78000b81485f037c507933cd024e70f39850..7f90d731b8454a020ab273e6b5591ed90aab14c7 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -22,7 +22,7 @@ limitations under the License. namespace tflite { namespace tensor_utils { -// Limit a float input f betweeen +abs_limit and -abs_limit. +// Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); // Multiply a matrix by a batch vector, and store results in a batch-size diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index b9ca3d5c626dff4ea8ba52949e8fea8e9b43689f..14c430258740b65dce65816f7c5c41fccf6dd5cf 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -2449,6 +2449,40 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, } } +template +void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims) { + // The current ArgMax implemention can only determine the index of the maximum + // value in the last dimension. So the axis argument is ignored. + TFLITE_DCHECK_EQ(axis[0], 3); + + // For ArgMax, the number of output dimensions = (number of input dimensions - + // 1). For the sake of simplicity, the output dimensions are equal to the + // input dimensions here. We enforce the constraint that the last dimension + // must always be 1. + TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = ArraySize(input_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + auto max_value = input_data[Offset(input_dims, 0, x, y, b)]; + int max_index = 0; + for (int d = 1; d < depth; ++d) { + const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)]; + if (curr_value > max_value) { + max_value = curr_value; + max_index = d; + } + } + output_data[Offset(output_dims, 0, x, y, b)] = max_index; + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 0e69ef5982f01e364d865684652d1dfecab6fee3..e7e2994397650004c7ba442fa1803290e6b12302 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -20,7 +20,7 @@ limitations under the License. namespace tflite { namespace tensor_utils { -// Limit a float input f betweeen +abs_limit and -abs_limit. +// Limit a float input f between +abs_limit and -abs_limit. float Clip(float f, float abs_limit); // Multiply a matrix by a batch vector, and store results in a batch-size diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc index b1db89b8bd3474ac868d7215e4a0de12088c48ef..30e103f3303484c339ef98e6a68e0438291c102f 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -57,7 +57,7 @@ TEST(L2NormOpTest, SimpleTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc index 63a8b0a3d0186def7da2c9f31481721f1a55281c..d75ce258a04c820d8f82735988c01d0154ef36f2 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc +++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc @@ -95,7 +95,7 @@ TEST(LocalResponseNormOpTest, SmallRadius) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc index 1011927848d586c8541fb694914b5eee123cb8dc..414d728dfc153058ec878d3c766f58e86815cd3f 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc @@ -117,7 +117,7 @@ TEST(LSHProjectionOpTest2, Sparse3DInputs) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index be4c7ddbf88fc902368cda13aff72f5aecb9dac4..c068286b0d84bcb51ebb0e239350a42863de6523 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -1081,8 +1081,7 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index 4b858e1f396252e7f7bdc231bc1e00f47277f08a..4255cfe18a043c55f3ce7292afdedb6e988a28a2 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -120,8 +120,7 @@ TEST(QuantizedMulOpTest, NoActivation) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h index 7535afaf8ea52d855e2e4773e56ce2118a16447c..63670efcb1e6349317aa5c75756707fb7a7fa2aa 100644 --- a/tensorflow/contrib/lite/kernels/op_macros.h +++ b/tensorflow/contrib/lite/kernels/op_macros.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#include + #define TF_LITE_FATAL(msg) \ do { \ fprintf(stderr, "%s\n", (msg)); \ diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index 8e9cc07656c8bea83f7cb78ca0b6cc5de7ad1b73..17166715ca30ff3d8ba3d384110e403f8910e39d 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -334,8 +334,7 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e90282a43b1b6caf7918b3874fd4273f59e31b7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pad { + +// This file has two implementations of Pad. +enum KernelType { + kReference, + kGenericOptimized, +}; + +// TODO(nupurgarg): Padding represented as a tensor is ignored. Only use the +// `left_padding` and `right_padding` specified in `params`. +struct PadContext { + PadContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLitePadParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Determines size of output tensor. + PadContext op_context(context, node); + int dims = NumDimensions(op_context.input); + TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions); + + // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, dims, 4); + + const TfLiteIntArray* input_size = op_context.input->dims; + TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims); + for (int idx = 0; idx < dims; ++idx) { + TF_LITE_ENSURE_MSG(context, + (op_context.params->before_padding[idx] >= 0 && + op_context.params->after_padding[idx] >= 0), + "Pad value has to be greater than equal to 0."); + output_size->data[idx] = + (input_size->data[idx] + op_context.params->before_padding[idx] + + op_context.params->after_padding[idx]); + } + + return context->ResizeTensor(context, op_context.output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + PadContext op_context(context, node); + + // TODO(nupurgarg): Support different data types. + if (op_context.output->type == kTfLiteFloat32) { + std::vector before_padding( + op_context.params->before_padding, + op_context.params->before_padding + op_context.params->num_dimensions); + std::vector after_padding( + op_context.params->after_padding, + op_context.params->after_padding + op_context.params->num_dimensions); + + // TODO(nupurgarg): Change TOCO's implementation to use padding arrays + // in forward order (depth, width, height, batch). + // Converts from int[] = {depth, width, height, batch} to int[] = {batch, + // height, width, depth} to match TOCO's implementation of pad in + // referenced_ops.h and optimized_ops.h. + std::reverse(before_padding.begin(), before_padding.end()); + std::reverse(after_padding.begin(), after_padding.end()); + +#define TF_LITE_PAD(type) \ + type::Pad(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), before_padding, after_padding, \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops); + } + if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops); + } +#undef TF_LITE_PAD + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace pad + +TfLiteRegistration* Register_PAD_REF() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval}; + return &r; +} + +TfLiteRegistration* Register_PAD_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval}; + return &r; +} + +TfLiteRegistration* Register_PAD() { + return Register_PAD_GENERIC_OPT(); + // return Register_PAD_REF(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3ea9417df0e61dcff7a877726ab91c9b22691ba --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pad_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class PadOpModel : public SingleOpModel { + public: + PadOpModel(std::initializer_list input_shape, + std::initializer_list before_padding, + std::initializer_list after_padding) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_, builder_.CreateVector(before_padding), + builder_.CreateVector(after_padding)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(PadOpTest, TooManyDimensions) { + EXPECT_DEATH( + PadOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}), + "dims != 4"); +} + +// TODO(nupurgarg): Test case where before padding and after padding arrays +// don't contain the same number of dimensions. +TEST(PadOpTest, UnequalDimensions) { + EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {1, 2, 3}, {1, 2, 3}), + "dims != op_context.params->num_dimensions"); +} + +TEST(PadOpTest, InvalidPadValue) { + EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {0, 1, 2, 0}, {0, -1, -1, 0}), + "Pad value has to be greater than equal to 0."); +} + +TEST(PadOpTest, SimpleTest) { + PadOpModel m({1, 2, 2, 1}, {0, 1, 1, 0}, {0, 1, 1, 0}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, + 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadOpTest, AdvancedTest) { + // The padding is input in the order of batch, height, width, depth. + PadOpModel m({1, 2, 3, 1}, {0, 0, 1, 0}, {0, 2, 3, 0}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc index e1b51ec7d5141bf2a41e7ede3e90ff20ec523819..01c91b2ba905e249c36af19f175c68a7e7f17f6d 100644 --- a/tensorflow/contrib/lite/kernels/pooling_test.cc +++ b/tensorflow/contrib/lite/kernels/pooling_test.cc @@ -155,7 +155,7 @@ TEST(FloatPoolingOpTest, L2Pool) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index ca7a0dd1949a3a31d26be770a7df781cc5fe7533..d4e7503f48debbdc092ad7950ee4c0e52854c432 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -31,6 +31,7 @@ TfLiteRegistration* Register_CONV_2D(); TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); TfLiteRegistration* Register_SVDF(); TfLiteRegistration* Register_RNN(); +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN(); TfLiteRegistration* Register_EMBEDDING_LOOKUP(); TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); TfLiteRegistration* Register_FULLY_CONNECTED(); @@ -39,14 +40,17 @@ TfLiteRegistration* Register_HASHTABLE_LOOKUP(); TfLiteRegistration* Register_SOFTMAX(); TfLiteRegistration* Register_CONCATENATION(); TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_BATCH_TO_SPACE_ND(); TfLiteRegistration* Register_MUL(); TfLiteRegistration* Register_L2_NORMALIZATION(); TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_PAD(); TfLiteRegistration* Register_RESHAPE(); TfLiteRegistration* Register_RESIZE_BILINEAR(); TfLiteRegistration* Register_SKIP_GRAM(); TfLiteRegistration* Register_SPACE_TO_DEPTH(); +TfLiteRegistration* Register_GATHER(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -61,6 +65,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()); AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + Register_UNIDIRECTIONAL_SEQUENCE_RNN()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, Register_EMBEDDING_LOOKUP_SPARSE()); @@ -70,15 +76,18 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); AddBuiltin(BuiltinOperator_ADD, Register_ADD()); + AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND()); AddBuiltin(BuiltinOperator_MUL, Register_MUL()); AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_PAD, Register_PAD()); AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); + AddBuiltin(BuiltinOperator_GATHER, Register_GATHER()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc index 59ce7d5648c04f78123b16a195d3a4928d28394b..0fbcf6e6aa311d2cac491336ee54ccf58bbda8fd 100644 --- a/tensorflow/contrib/lite/kernels/reshape_test.cc +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -83,8 +83,7 @@ TEST(ReshapeOpTest, WithStretchDimension) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 0257c0b557feb352413bcc33cb4e2ecdb32c5111..314a71e210d9b5ea75bb137ef228273ef48f28b5 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -111,7 +111,7 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc index e7f6bc904be5e4c23a88f5b4ae7e199346c78ab2..185b64cb44969b57588ea5d0b40f55b6ddf8e11f 100644 --- a/tensorflow/contrib/lite/kernels/skip_gram_test.cc +++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc @@ -251,7 +251,7 @@ TEST(SkipGramTest, TestInputWithExtraSpace) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc index ec8ec03b0d0279cad8543352b1dbaf34c88a7957..6c5338ff0fd26337c9adc8e0b94a0a88edfde37f 100644 --- a/tensorflow/contrib/lite/kernels/softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -136,8 +136,7 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc index 911f08a92ccd6a97bee414c87bd79091808f0ed1..997f354861a235fb511235e4d64544dc8c3ddb34 100644 --- a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc +++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc @@ -95,8 +95,7 @@ TEST(SpaceToDepthOpModel, Int64) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index d956025e9dfc9b6c03e55657023fb042c8ac485d..4de2ceaf053df31a4bc857fb250db416c071e80f 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -306,7 +306,7 @@ TEST(SVDFOpTest, BlackBoxTestRank2) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index f716ba8741fd469e7ee405ac300924b53c5c48e5..b69f2b3e4bc66c94fdfc7ed4c244151be63a1711 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -180,4 +180,17 @@ int32_t SingleOpModel::GetTensorSize(int index) const { return total_size; } +template <> +std::vector SingleOpModel::ExtractVector(int index) { + TfLiteTensor* tensor_ptr = interpreter_->tensor(index); + CHECK(tensor_ptr != nullptr); + const int num_strings = GetStringCount(tensor_ptr); + std::vector result; + result.reserve(num_strings); + for (int i = 0; i < num_strings; ++i) { + const auto str = GetString(tensor_ptr, i); + result.emplace_back(str.str, str.len); + } + return result; +} } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index e68e49466119c50ec123edb84f1b1b6390a15a60..531c1366a87e20e140e779b767e29b1fd1111f97 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -24,16 +24,11 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/testing/util.h" #include "tensorflow/core/platform/logging.h" namespace tflite { -inline void LogToStderr() { -#ifdef PLATFORM_GOOGLE - FLAGS_logtostderr = true; -#endif -} - // A gmock matcher that check that elements of a float vector match to a given // tolerance. std::vector<::testing::Matcher> ArrayFloatNear( @@ -197,6 +192,9 @@ class SingleOpModel { std::map> custom_registrations_; }; +// Strings have a special implementation that is in test_util.cc +template <> +std::vector SingleOpModel::ExtractVector(int index); } // namespace tflite #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc new file mode 100644 index 0000000000000000000000000000000000000000..85e09049eea5f66a2bb854990bf80e9ed5dcc88a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -0,0 +1,169 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace unidirectional_sequence_rnn { + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kRecurrentWeightsTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int KHiddenStateTensor = 0; +constexpr int kOutputTensor = 1; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int batch_size = input->dims->data[0]; + const int max_time = input->dims->data[1]; + const int num_units = input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[2], input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Resize state. + TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); + hidden_state_size_array->data[0] = batch_size; + hidden_state_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, + hidden_state_size_array)); + + // Mark hidden state as a persistent tensor. + hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = max_time; + output_size_array->data[2] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, + output_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Initialize the pointer bias. + const float* bias_ptr = bias->data.f; + + const int batch_size = input->dims->data[0]; + const int max_time = input->dims->data[1]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[2]; + const int input_weights_stride = input_weights->dims->data[1]; + const int recurrent_weights_stride = recurrent_weights->dims->data[1]; + + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + for (int s = 0; s < max_time; s++) { + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + output->data.f + b * num_units * max_time + s * num_units; + + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + // Output = bias + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = bias_ptr[o]; + } + + // Output += input * input_weights + for (int o = 0; o < num_units; o++) { + for (int i = 0; i < input_size; i++) { + output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; + } + input_weights_ptr += input_weights_stride; + } + + // Output += recurrent_weights * hidden_state + for (int o = 0; o < num_units; o++) { + for (int h = 0; h < num_units; h++) { + output_ptr_batch[o] += + hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; + } + recurrent_weights_ptr += recurrent_weights_stride; + } + + // Output = activation(Output) and update hidden_state + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = + (ActivationFunctor(params->activation))(output_ptr_batch[o]); + hidden_state_ptr_batch[o] = output_ptr_batch[o]; + } + } + } + return kTfLiteOk; +} + +} // namespace unidirectional_sequence_rnn + +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + unidirectional_sequence_rnn::Prepare, + unidirectional_sequence_rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a1c1eda16034f83ca5c79fc18f4fa495a3e73f90 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -0,0 +1,270 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite RNN op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0 +}; + +class UnidirectionalRNNOpModel : public SingleOpModel { + public: + UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size) + : batches_(batches), + sequence_len_(sequence_len), + units_(units), + input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(TensorType_FLOAT32); + recurrent_weights_ = AddInput(TensorType_FLOAT32); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + BuildInterpreter({{batches_, sequence_len_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + int sequence_len() { return sequence_len_; } + + private: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int sequence_len_; + int units_; + int input_size_; +}; + +// TODO(mirkov): add another test which directly compares to TF once TOCO +// supports the conversion from dynamic_rnn with BasicRNNCell. +TEST(FullyConnectedOpTest, BlackBoxTest) { + UnidirectionalRNNOpModel rnn(2, 16, 16, 8); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output; + float* golden_end = golden_start + rnn.num_units() * rnn.sequence_len(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index e2f3560e61baae88a4afaafaa202cde784063efc..94e22b265964b300c862a9ee52511d479c20c64d 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -60,6 +60,14 @@ std::unique_ptr FlatBufferModel::BuildFromBuffer( return model; } +std::unique_ptr FlatBufferModel::BuildFromModel( + const tflite::Model* model_spec, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(model_spec, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, ErrorReporter* error_reporter, bool use_nnapi) : error_reporter_(error_reporter ? error_reporter @@ -99,6 +107,13 @@ FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); } +FlatBufferModel::FlatBufferModel(const Model* model, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + model_ = model; +} + FlatBufferModel::~FlatBufferModel() { delete allocation_; } InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, @@ -160,6 +175,27 @@ std::vector FlatBufferIntArrayToVector(T* flat_array) { return ret; } +// Copies the contents from the flatbuffer int vector `flatbuffer` into the +// int array `buffer`. `flat_vector` and `buffer` represent the same +// configuration operation for a given operation. +void FlatBufferIntVectorToArray(int max_size_of_buffer, + const flatbuffers::Vector* flat_vector, + int* buffer, ErrorReporter* error_reporter) { + if (!flat_vector) { + error_reporter->Report("Input array not provided for operation.\n"); + } else { + int num_dimensions = flat_vector->Length(); + if (num_dimensions > max_size_of_buffer / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in the operation's input array.\n"); + } else { + for (int i = 0; i < num_dimensions; ++i) { + buffer[i] = flat_vector->Get(i); + } + } + } +} + // Allocate a structure using C malloc, but make sure the structure is a // POD structure that doesn't require constructors to run. The reason we do // this, is that Interpreter's C extension part will take ownership and wants @@ -175,6 +211,9 @@ T* MallocPOD() { // This handles builtin data explicitly as there are flatbuffer schemas. // // Returns memory that must be feed. +// +// TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program +// crashes if error reporter is called. void* ParseOpData(const Operator* op, BuiltinOperator op_type, ErrorReporter* error_reporter) { auto parse_padding = [](Padding padding) { @@ -301,6 +340,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: case BuiltinOperator_RNN: { TfLiteRNNParams* params = MallocPOD(); if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { @@ -417,23 +457,35 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_PAD: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_PadOptions()) { + auto* before_padding = schema_params->before_padding(); + FlatBufferIntVectorToArray(sizeof(params->before_padding), + before_padding, params->before_padding, + error_reporter); + + auto* after_padding = schema_params->after_padding(); + FlatBufferIntVectorToArray(sizeof(params->after_padding), after_padding, + params->after_padding, error_reporter); + + if (before_padding->Length() != after_padding->Length()) { + error_reporter->Report( + "Before padding and after padding arrays need to contain the " + "same number of dimensions.\n"); + } + params->num_dimensions = after_padding->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { auto* new_shape = schema_params->new_shape(); - if (!new_shape) { - error_reporter->Report("No new_shape provided for Reshape\n"); - } else { - params->num_dimensions = new_shape->Length(); - if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) { - error_reporter->Report( - "Found too many dimensions in Reshape's new_shape\n"); - } else { - for (int i = 0; i < params->num_dimensions; ++i) { - params->shape[i] = new_shape->Get(i); - } - } - } + FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, + params->shape, error_reporter); + params->num_dimensions = new_shape->Length(); } builtin_data = reinterpret_cast(params); break; @@ -456,6 +508,34 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_GATHER: { + TfLiteGatherParams* params = MallocPOD(); + params->axis = 0; + if (auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + } + + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_BATCH_TO_SPACE_ND: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_BatchToSpaceNDOptions()) { + const auto& block_shape = schema_params->block_shape(); + FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, + params->block_shape, error_reporter); + const auto& before_crops = schema_params->before_crops(); + FlatBufferIntVectorToArray(sizeof(params->before_crops), before_crops, + params->before_crops, error_reporter); + const auto& after_crops = schema_params->after_crops(); + FlatBufferIntVectorToArray(sizeof(params->after_crops), after_crops, + params->after_crops, error_reporter); + params->num_spatial_dimensions = block_shape->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } } return builtin_data; } diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 15659d33f37dfb2f119480ed88d2e1b81f34c145..e0c96f7f0480cd3146f95a22957477809cf0096d 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -45,18 +45,25 @@ namespace tflite { // or mmapped. This uses flatbuffers as the serialization format. class FlatBufferModel { public: - // Build a model based on a file. Return a nullptr in case of failure. + // Builds a model based on a file. Returns a nullptr in case of failure. static std::unique_ptr BuildFromFile( const char* filename, ErrorReporter* error_reporter = DefaultErrorReporter()); - // Build a model based on a pre-loaded flatbuffer. The caller retains + // Builds a model based on a pre-loaded flatbuffer. The caller retains // ownership of the buffer and should keep it alive until the returned object - // is destroyed. Return a nullptr in case of failure. + // is destroyed. Returns a nullptr in case of failure. static std::unique_ptr BuildFromBuffer( const char* buffer, size_t buffer_size, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Builds a model directly from a flatbuffer pointer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromModel( + const tflite::Model* model_spec, + ErrorReporter* error_reporter = DefaultErrorReporter()); + // Releases memory or unmaps mmaped meory. ~FlatBufferModel(); @@ -75,7 +82,7 @@ class FlatBufferModel { bool CheckModelIdentifier() const; private: - // Load a model from `filename`. If `mmap_file` is true then use mmap, + // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. // // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be @@ -85,8 +92,8 @@ class FlatBufferModel { ErrorReporter* error_reporter = DefaultErrorReporter(), bool use_nnapi = false); - // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has to - // remain alive and unchanged until the end of this flatbuffermodel's + // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has + // to remain alive and unchanged until the end of this flatbuffermodel's // lifetime. // // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be @@ -94,6 +101,10 @@ class FlatBufferModel { FlatBufferModel(const char* ptr, size_t num_bytes, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Loads a model from Model flatbuffer. The `model` has to remain alive and + // unchanged until the end of this flatbuffermodel's lifetime. + FlatBufferModel(const Model* model, ErrorReporter* error_reporter); + // Flatbuffer traverser pointer. (Model* is a pointer that is within the // allocated memory of the data allocated by allocation's internals. const tflite::Model* model_ = nullptr; @@ -106,9 +117,9 @@ class FlatBufferModel { // model are mapped to executable function pointers (TfLiteRegistrations). class OpResolver { public: - // Find the op registration for a builtin operator by enum code. + // Finds the op registration for a builtin operator by enum code. virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; - // Find the op registration of a custom operator by op name. + // Finds the op registration of a custom operator by op name. virtual TfLiteRegistration* FindOp(const char* op) const = 0; virtual ~OpResolver() {} }; @@ -131,7 +142,7 @@ class InterpreterBuilder { public: InterpreterBuilder(const FlatBufferModel& model, const OpResolver& op_resolver); - // Build an interpreter given only the raw flatbuffer Model object (instead + // Builds an interpreter given only the raw flatbuffer Model object (instead // of a FlatBufferModel). Mostly used for testing. // If `error_reporter` is null, then DefaultErrorReporter() is used. InterpreterBuilder(const ::tflite::Model* model, diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index 61043866420752b552281e353be9a2b41a6aadc8..5330c8f594593655b2a8776cf6b399c0d16cdc19 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/testing/util.h" // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, // we must declare this in global namespace, so argument-dependent operator @@ -254,6 +255,28 @@ TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { ASSERT_FALSE(model); } +// Test that loading model directly from a Model flatbuffer works. +TEST(BasicFlatBufferModel, TestBuildFromModel) { + TestErrorReporter reporter; + FileCopyAllocation model_allocation( + "tensorflow/contrib/lite/testdata/test_model.bin", &reporter); + ASSERT_TRUE(model_allocation.valid()); + ::flatbuffers::Verifier verifier( + reinterpret_cast(model_allocation.base()), + model_allocation.bytes()); + ASSERT_TRUE(VerifyModelBuffer(verifier)); + const Model* model_fb = ::tflite::GetModel(model_allocation.base()); + + auto model = FlatBufferModel::BuildFromModel(model_fb); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); +} + // TODO(aselle): Add tests for serialization of builtin op data types. // These tests will occur with the evaluation tests of individual operators, // not here. @@ -261,7 +284,7 @@ TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD index fbdf19f2054cf01aec44e3fcb13d0d0a2ff6f914..733c3f4c7fa0605f24a1e6b4c458e34310c079c4 100644 --- a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -1,7 +1,92 @@ package(default_visibility = ["//visibility:public"]) +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + licenses(["notice"]) # Apache 2.0 +gen_selected_ops( + name = "smartreply_ops", + model = "@tflite_smartreply//:smartreply.tflite", +) + +cc_library( + name = "custom_ops", + srcs = [ + "ops/extract_feature.cc", + "ops/normalize.cc", + "ops/predict.cc", + ":smartreply_ops", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + "@farmhash_archive//:farmhash", + ], +) + +cc_library( + name = "predictor_lib", + srcs = ["predictor.cc"], + hdrs = ["predictor.h"], + copts = tflite_copts(), + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "extract_feature_op_test", + size = "small", + srcs = ["ops/extract_feature_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@farmhash_archive//:farmhash", + ], +) + +cc_test( + name = "normalize_op_test", + size = "small", + srcs = ["ops/normalize_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "predict_op_test", + size = "small", + srcs = ["ops/predict_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..75ed9432c8fcdfd77a64d3c659e6336c977cdda2 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f8767b443a2aa64b666c3b6bfb7db30cc0be62ea --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -0,0 +1,65 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +filegroup( + name = "assets", + srcs = [ + "@tflite_smartreply//:model_files", + ], +) + +android_binary( + name = "SmartReplyDemo", + srcs = glob(["java/**/*.java"]), + assets = [":assets"], + assets_dir = "", + custom_package = "com.example.android.smartreply", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + ":smartreply_runtime", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +cc_library( + name = "smartreply_runtime", + srcs = ["libsmartreply_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libsmartreply_jni.so", + deps = [ + ":smartreply_jni_lib", + ], +) + +cc_library( + name = "smartreply_jni_lib", + srcs = [ + "smartreply_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/models/smartreply:predictor_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3c882ffc43fde577801428151a43b592e8faaed1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(glob(["*"])) + +filegroup( + name = "assets_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0a5b46b5f8d5fd6a0297c8056bb2fb9b6ad9ada --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt @@ -0,0 +1,16 @@ +Ok +Yes +No +👍 +☺ +😟 +❤️ +Lol +Thanks +Got it +Done +Nice +I don't know +What? +Why? +What's up? diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..02fec9ae5e971ad756ae6c2b0149a6aacfa27cad --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.app.Activity; +import android.os.Bundle; +import android.os.Handler; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.EditText; +import android.widget.TextView; + +/** + * The main (and only) activity of this demo app. Displays a text box which updates as messages are + * received. + */ +public class MainActivity extends Activity { + private static final String TAG = "SmartReplyDemo"; + private SmartReplyClient client; + + private Button sendButton; + private TextView messageTextView; + private EditText messageInput; + + private Handler handler; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + Log.v(TAG, "onCreate"); + setContentView(R.layout.main_activity); + + client = new SmartReplyClient(getApplicationContext()); + handler = new Handler(); + + sendButton = (Button) findViewById(R.id.send_button); + sendButton.setOnClickListener( + (View v) -> { + send(messageInput.getText().toString()); + }); + + messageTextView = (TextView) findViewById(R.id.message_text); + messageInput = (EditText) findViewById(R.id.message_input); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.loadModel(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unloadModel(); + }); + } + + private void send(final String message) { + handler.post( + () -> { + messageTextView.append("Input: " + message + "\n"); + + SmartReply[] ans = client.predict(new String[] {message}); + for (SmartReply reply : ans) { + appendMessage("Reply: " + reply.getText()); + } + appendMessage("------"); + }); + } + + private void appendMessage(final String message) { + handler.post( + () -> { + messageTextView.append(message + "\n"); + }); + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java new file mode 100644 index 0000000000000000000000000000000000000000..3357fd17c11f870d1b0998bb26ffa9abf149686b --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.support.annotation.Keep; + +/** + * SmartReply contains predicted message, and confidence. + * + *

NOTE: this class used by JNI, class name and constructor should not be obfuscated. + */ +@Keep +public class SmartReply { + + private final String text; + private final float score; + + @Keep + public SmartReply(String text, float score) { + this.text = text; + this.score = score; + } + + public String getText() { + return text; + } + + public float getScore() { + return score; + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java new file mode 100644 index 0000000000000000000000000000000000000000..d5b1ac0ffbc47283aa0c1bf68c0a85ad6228cdcc --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.support.annotation.Keep; +import android.support.annotation.WorkerThread; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.List; + +/** Interface to load TfLite model and provide predictions. */ +public class SmartReplyClient implements AutoCloseable { + private static final String TAG = "SmartReplyDemo"; + private static final String MODEL_PATH = "smartreply.tflite"; + private static final String BACKOFF_PATH = "backoff_response.txt"; + private static final String JNI_LIB = "smartreply_jni"; + + private final Context context; + private long storage; + private MappedByteBuffer model; + + private volatile boolean isLibraryLoaded; + + public SmartReplyClient(Context context) { + this.context = context; + } + + public boolean isLoaded() { + return storage != 0; + } + + @WorkerThread + public synchronized void loadModel() { + if (!isLibraryLoaded) { + System.loadLibrary(JNI_LIB); + isLibraryLoaded = true; + } + + try { + model = loadModelFile(); + String[] backoff = loadBackoffList(); + storage = loadJNI(model, backoff); + } catch (IOException e) { + Log.e(TAG, "Fail to load model", e); + return; + } + } + + @WorkerThread + public synchronized SmartReply[] predict(String[] input) { + if (storage != 0) { + return predictJNI(storage, input); + } else { + return new SmartReply[] {}; + } + } + + @WorkerThread + public synchronized void unloadModel() { + close(); + } + + @Override + public synchronized void close() { + if (storage != 0) { + unloadJNI(storage); + storage = 0; + } + } + + private MappedByteBuffer loadModelFile() throws IOException { + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + try { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } finally { + inputStream.close(); + } + } + + private String[] loadBackoffList() throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH))); + String line; + while ((line = reader.readLine()) != null) { + if (!line.isEmpty()) { + labelList.add(line); + } + } + reader.close(); + String[] ans = new String[labelList.size()]; + labelList.toArray(ans); + return ans; + } + + @Keep + private native long loadJNI(MappedByteBuffer buffer, String[] backoff); + + @Keep + private native SmartReply[] predictJNI(long storage, String[] text); + + @Keep + private native void unloadJNI(long storage); +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml new file mode 100644 index 0000000000000000000000000000000000000000..23b4cadc007a4457d33b8c8fecf9b1e7b7436320 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml @@ -0,0 +1,44 @@ + + + + + + + + + + +